def test_create_selection_reraises_type_error(self, mock_executor_grpc_stub):
    """A TypeError raised by the underlying gRPC stub propagates unchanged."""
    mocked_grpc = mock_executor_grpc_stub.return_value
    # Every call to CreateSelection on the mocked gRPC stub raises TypeError.
    mocked_grpc.CreateSelection = mock.Mock(side_effect=TypeError)
    client = create_stub()

    with self.assertRaises(TypeError):
        client.create_selection(
            request=executor_pb2.CreateSelectionRequest())
Example #2
0
 async def create_selection(self, source, index):
   """Selects the element at position `index` of the remote struct `source`.

   Args:
     source: A `RemoteValue` whose `type_signature` must be a
       `computation_types.StructType`.
     index: The integer position of the element to select.

   Returns:
     A `RemoteValue` wrapping the `value_ref` from the service response,
     typed as `source.type_signature[index]` and bound to this executor.

   Raises:
     Whatever `py_typecheck.check_type` raises when `source`, its type
     signature, `index`, or the service response has an unexpected type.
   """
   py_typecheck.check_type(source, RemoteValue)
   struct_type = source.type_signature
   py_typecheck.check_type(struct_type, computation_types.StructType)
   py_typecheck.check_type(index, int)
   element_type = struct_type[index]
   selection_response = _request(
       self._stub.CreateSelection,
       executor_pb2.CreateSelectionRequest(
           source_ref=source.value_ref, index=index))
   py_typecheck.check_type(selection_response,
                           executor_pb2.CreateSelectionResponse)
   return RemoteValue(selection_response.value_ref, element_type, self)
    def test_create_selection_raises_retryable_error_on_grpc_error_unavailable(
            self, mock_executor_grpc_stub):
        """The gRPC error from `_raise_grpc_error_unavailable` becomes a RetryableError."""
        mocked_grpc = mock_executor_grpc_stub.return_value
        mocked_grpc.CreateSelection = mock.Mock(
            side_effect=_raise_grpc_error_unavailable)
        client = create_stub()

        with self.assertRaises(executors_errors.RetryableError):
            client.create_selection(
                request=executor_pb2.CreateSelectionRequest())
    def test_create_selection_reraises_non_retryable_grpc_error(
            self, mock_executor_grpc_stub):
        """A non-retryable gRPC error surfaces as grpc.RpcError with its code."""
        mocked_grpc = mock_executor_grpc_stub.return_value
        mocked_grpc.CreateSelection = mock.Mock(
            side_effect=_raise_non_retryable_grpc_error)
        client = create_stub()
        with self.assertRaises(grpc.RpcError) as raised:
            client.create_selection(
                request=executor_pb2.CreateSelectionRequest())

        # The assertion below expects the helper's error to carry ABORTED.
        status_code = raised.exception.code()
        self.assertEqual(status_code, grpc.StatusCode.ABORTED)
    def test_create_selection_returns_value(self, mock_executor_grpc_stub):
        """The stub returns the response object produced by the gRPC call."""
        expected = executor_pb2.CreateSelectionResponse()
        mocked_grpc = mock_executor_grpc_stub.return_value
        # A one-element side_effect iterable: the first call yields `expected`;
        # a second call would raise StopIteration.
        mocked_grpc.CreateSelection = mock.Mock(side_effect=[expected])
        client = create_stub()

        actual = client.create_selection(
            request=executor_pb2.CreateSelectionRequest())

        mocked_grpc.CreateSelection.assert_called_once()
        self.assertEqual(actual, expected)
 async def create_selection(self, source, index=None, name=None):
     """Selects one element of the remote struct `source`.

     The element is identified by `index` or, when `index` is None, by
     `name`. A name is resolved to a position on the client side, so the
     request sent to the service always carries an index.

     Args:
       source: A `RemoteValue` whose `type_signature` must be a
         `computation_types.StructType`.
       index: Optional integer position of the element to select.
       name: Optional element name, used only when `index` is None. It must
         be None when `index` is given.

     Returns:
       A `RemoteValue` wrapping the `value_ref` from the service response,
       typed as the selected element's type and bound to this executor.

     Raises:
       Whatever `py_typecheck.check_type` / `py_typecheck.check_none` raise
       on an unexpected argument or response type. Presumably a `KeyError`
       when `name` is not an element name (it is looked up with `[name]` in
       the map from `structure.name_to_index_map`; TODO confirm the map type).
     """
     py_typecheck.check_type(source, RemoteValue)
     py_typecheck.check_type(source.type_signature,
                             computation_types.StructType)
     if index is None:
         py_typecheck.check_type(name, str)
         index = structure.name_to_index_map(source.type_signature)[name]
     else:
         # Passing both selectors is ambiguous. Reject it rather than
         # silently ignoring `name`, which could select the wrong element.
         py_typecheck.check_none(name)
     py_typecheck.check_type(index, int)
     result_type = source.type_signature[index]
     request = executor_pb2.CreateSelectionRequest(
         source_ref=source.value_ref, index=index)
     response = _request(self._stub.CreateSelection, request)
     py_typecheck.check_type(response, executor_pb2.CreateSelectionResponse)
     return RemoteValue(response.value_ref, result_type, self)
Example #7
0
 async def create_selection(self, source, index=None, name=None):
     """Selects one element of the remote named tuple `source`.

     The element is chosen by `index` (with `name` left as None) or, when
     `index` is None, by `name`. Both selector arguments are passed to the
     service as given in the `CreateSelectionRequest`.

     Args:
       source: A `RemoteValue` whose `type_signature` must be a
         `computation_types.NamedTupleType`.
       index: Optional integer position of the element.
       name: Optional element name, required to be a str when `index` is None
         and required to be None otherwise.

     Returns:
       A `RemoteValue` wrapping the `value_ref` returned by the gRPC stub,
       typed as the selected element's type and bound to this executor.
     """
     py_typecheck.check_type(source, RemoteValue)
     tuple_type = source.type_signature
     py_typecheck.check_type(tuple_type, computation_types.NamedTupleType)
     if index is None:
         py_typecheck.check_type(name, str)
         selected_type = getattr(tuple_type, name)
     else:
         py_typecheck.check_type(index, int)
         py_typecheck.check_none(name)
         selected_type = tuple_type[index]
     selection_request = executor_pb2.CreateSelectionRequest(
         source_ref=source.value_ref, name=name, index=index)
     reply = self._stub.CreateSelection(selection_request)
     py_typecheck.check_type(reply, executor_pb2.CreateSelectionResponse)
     return RemoteValue(reply.value_ref, selected_type, self)
    def test_executor_service_create_and_select_from_tuple(self):
        """Builds <a=10,b=20> via the service, then selects each element by index."""
        ex_factory = executor_test_utils.BasicTestExFactory(
            eager_tf_executor.EagerTFExecutor())
        env = TestEnv(ex_factory)

        # Push 10 and 20 as int32 values, checking each round-trips.
        created_refs = []
        for python_value in (10, 20):
            serialized, _ = value_serialization.serialize_value(
                python_value, tf.int32)
            reply = env.stub.CreateValue(
                executor_pb2.CreateValueRequest(executor=env.executor_pb,
                                                value=serialized))
            self.assertIsInstance(reply, executor_pb2.CreateValueResponse)
            created_ref = reply.value_ref
            self.assertEqual(env.get_value(created_ref.id), python_value)
            created_refs.append(created_ref)
        ten_ref, twenty_ref = created_refs

        # Assemble the named struct <a=10,b=20> from the two references.
        struct_elements = [
            executor_pb2.CreateStructRequest.Element(name=elem_name,
                                                     value_ref=elem_ref)
            for elem_name, elem_ref in (('a', ten_ref), ('b', twenty_ref))
        ]
        reply = env.stub.CreateStruct(
            executor_pb2.CreateStructRequest(executor=env.executor_pb,
                                             element=struct_elements))
        self.assertIsInstance(reply, executor_pb2.CreateStructResponse)
        tuple_ref = reply.value_ref
        self.assertEqual(str(env.get_value(tuple_ref.id)), '<a=10,b=20>')

        # Select each element by position and check its value.
        for position, expected in enumerate((10, 20)):
            reply = env.stub.CreateSelection(
                executor_pb2.CreateSelectionRequest(executor=env.executor_pb,
                                                    source_ref=tuple_ref,
                                                    index=position))
            self.assertIsInstance(reply,
                                  executor_pb2.CreateSelectionResponse)
            self.assertEqual(env.get_value(reply.value_ref.id), expected)

        del env
Example #9
0
    def test_executor_service_create_and_select_from_tuple(self):
        """Builds <a=10,b=20> via the service, then selects by name and index."""
        ex_factory = executor_stacks.ResourceManagingExecutorFactory(
            lambda _: eager_tf_executor.EagerTFExecutor())
        env = TestEnv(ex_factory)

        # Push 10 and 20 as int32 values, checking each round-trips.
        created_refs = []
        for python_value in (10, 20):
            serialized, _ = executor_serialization.serialize_value(
                python_value, tf.int32)
            reply = env.stub.CreateValue(
                executor_pb2.CreateValueRequest(value=serialized))
            self.assertIsInstance(reply, executor_pb2.CreateValueResponse)
            created_ref = reply.value_ref
            self.assertEqual(env.get_value(created_ref.id), python_value)
            created_refs.append(created_ref)
        ten_ref, twenty_ref = created_refs

        # Assemble the named struct <a=10,b=20> from the two references.
        struct_elements = [
            executor_pb2.CreateStructRequest.Element(name=elem_name,
                                                     value_ref=elem_ref)
            for elem_name, elem_ref in (('a', ten_ref), ('b', twenty_ref))
        ]
        reply = env.stub.CreateStruct(
            executor_pb2.CreateStructRequest(element=struct_elements))
        self.assertIsInstance(reply, executor_pb2.CreateStructResponse)
        tuple_ref = reply.value_ref
        self.assertEqual(str(env.get_value(tuple_ref.id)), '<a=10,b=20>')

        # Each selector is the keyword form of one CreateSelectionRequest field.
        selectors = (
            ({'name': 'a'}, 10),
            ({'name': 'b'}, 20),
            ({'index': 0}, 10),
            ({'index': 1}, 20),
        )
        for selector, expected in selectors:
            reply = env.stub.CreateSelection(
                executor_pb2.CreateSelectionRequest(source_ref=tuple_ref,
                                                    **selector))
            self.assertIsInstance(reply,
                                  executor_pb2.CreateSelectionResponse)
            self.assertEqual(env.get_value(reply.value_ref.id), expected)

        del env
 async def create_selection(self, source, index=None, name=None):
   """Selects one element of the remote struct `source`.

   The element is chosen by `index` (with `name` left as None) or, when
   `index` is None, by `name`. The request goes over the bidirectional
   stream when one is set on the executor, and otherwise through the unary
   `CreateSelection` stub call.

   Args:
     source: A `RemoteValue` whose `type_signature` must be a
       `computation_types.StructType`.
     index: Optional integer position of the element.
     name: Optional element name, required to be a str when `index` is None
       and required to be None otherwise.

   Returns:
     A `RemoteValue` wrapping the `value_ref` from the service response,
     typed as the selected element's type and bound to this executor.
   """
   py_typecheck.check_type(source, RemoteValue)
   struct_type = source.type_signature
   py_typecheck.check_type(struct_type, computation_types.StructType)
   if index is None:
     py_typecheck.check_type(name, str)
     selected_type = getattr(struct_type, name)
   else:
     py_typecheck.check_type(index, int)
     py_typecheck.check_none(name)
     selected_type = struct_type[index]
   selection_request = executor_pb2.CreateSelectionRequest(
       source_ref=source.value_ref, name=name, index=index)
   if self._bidi_stream is not None:
     # Wrap the request in an ExecuteRequest and unwrap the matching reply.
     stream_reply = await self._bidi_stream.send_request(
         executor_pb2.ExecuteRequest(create_selection=selection_request))
     reply = stream_reply.create_selection
   else:
     reply = _request(self._stub.CreateSelection, selection_request)
   py_typecheck.check_type(reply, executor_pb2.CreateSelectionResponse)
   return RemoteValue(reply.value_ref, selected_type, self)
Example #11
0
 async def create_selection(self, source, index=None, name=None):
   """Selects one element of the remote named tuple `source`.

   The element is chosen by `index` (with `name` left as None) or, when
   `index` is None, by `name`. Without a bidirectional stream, the unary
   `CreateSelection` stub call is used. In that path a `grpc.RpcError` is
   handed to `self._handle_grpc_error` and then re-raised. With a stream,
   the request is wrapped in an `ExecuteRequest` and sent over the stream.

   Args:
     source: A `RemoteValue` whose `type_signature` must be a
       `computation_types.NamedTupleType`.
     index: Optional integer position of the element.
     name: Optional element name, required to be a str when `index` is None
       and required to be None otherwise.

   Returns:
     A `RemoteValue` wrapping the `value_ref` from the service response,
     typed as the selected element's type and bound to this executor.

   Raises:
     grpc.RpcError: re-raised from the unary call if `_handle_grpc_error`
       does not raise something else itself.
   """
   py_typecheck.check_type(source, RemoteValue)
   py_typecheck.check_type(source.type_signature,
                           computation_types.NamedTupleType)
   if index is not None:
     py_typecheck.check_type(index, int)
     py_typecheck.check_none(name)
     result_type = source.type_signature[index]
   else:
     py_typecheck.check_type(name, str)
     result_type = getattr(source.type_signature, name)
   request = executor_pb2.CreateSelectionRequest(
       source_ref=source.value_ref, name=name, index=index)
   if not self._bidi_stream:
     try:
       response = self._stub.CreateSelection(request)
     except grpc.RpcError as e:
       self._handle_grpc_error(e)
       # `_handle_grpc_error` is defined outside this view. If it returns
       # instead of raising, `response` would be unbound and the type check
       # below would fail with an unrelated UnboundLocalError. Re-raise so
       # the caller always sees the actual gRPC failure.
       raise
   else:
     response = (await self._bidi_stream.send_request(
         executor_pb2.ExecuteRequest(create_selection=request)
     )).create_selection
   py_typecheck.check_type(response, executor_pb2.CreateSelectionResponse)
   return RemoteValue(response.value_ref, result_type, self)