def test_create_struct_reraises_type_error(self, mock_executor_grpc_stub):
    """Check that a TypeError from the gRPC `CreateStruct` call propagates.

    Args:
        mock_executor_grpc_stub: Mock whose `return_value` receives the
            `CreateStruct` replacement (presumably injected by a patch
            decorator that is outside this view -- confirm).
    """
    # Make the underlying RPC fail with a plain TypeError.
    mock_executor_grpc_stub.return_value.CreateStruct = mock.Mock(
        side_effect=TypeError)
    stub_under_test = create_stub()
    empty_request = executor_pb2.CreateStructRequest()

    with self.assertRaises(TypeError):
        stub_under_test.create_struct(request=empty_request)
def test_create_struct_raises_retryable_error_on_grpc_error_unavailable(
        self, mock_executor_grpc_stub):
    """Check that an "unavailable" gRPC failure becomes a `RetryableError`.

    Args:
        mock_executor_grpc_stub: Mock whose `return_value` receives the
            `CreateStruct` replacement (presumably injected by a patch
            decorator that is outside this view -- confirm).
    """
    # The RPC is replaced by a mock that runs the module's
    # `_raise_grpc_error_unavailable` helper whenever it is called.
    failing_rpc = mock.Mock(side_effect=_raise_grpc_error_unavailable)
    mock_executor_grpc_stub.return_value.CreateStruct = failing_rpc
    stub_under_test = create_stub()
    empty_request = executor_pb2.CreateStructRequest()

    with self.assertRaises(executors_errors.RetryableError):
        stub_under_test.create_struct(request=empty_request)
def test_create_struct_reraises_grpc_error(self, mock_executor_grpc_stub):
    """Check that a non-retryable gRPC error is re-raised with its code.

    Args:
        mock_executor_grpc_stub: Mock whose `return_value` receives the
            `CreateStruct` replacement (presumably injected by a patch
            decorator that is outside this view -- confirm).
    """
    # The RPC is replaced by a mock that runs the module's
    # `_raise_non_retryable_grpc_error` helper whenever it is called.
    failing_rpc = mock.Mock(side_effect=_raise_non_retryable_grpc_error)
    mock_executor_grpc_stub.return_value.CreateStruct = failing_rpc
    stub_under_test = create_stub()
    empty_request = executor_pb2.CreateStructRequest()

    with self.assertRaises(grpc.RpcError) as caught:
        stub_under_test.create_struct(request=empty_request)

    # Inspect the captured exception only after the `with` block has exited.
    self.assertEqual(caught.exception.code(), grpc.StatusCode.ABORTED)
def test_create_struct_returns_value(self, mock_executor_grpc_stub):
    """Check that `create_struct` returns the response produced by the RPC.

    Args:
        mock_executor_grpc_stub: Mock whose `return_value` receives the
            `CreateStruct` replacement (presumably injected by a patch
            decorator that is outside this view -- confirm).
    """
    response = executor_pb2.CreateStructResponse()
    grpc_stub = mock_executor_grpc_stub.return_value
    # A one-item side_effect list: the first (and only) call yields `response`.
    grpc_stub.CreateStruct = mock.Mock(side_effect=[response])
    stub_under_test = create_stub()
    empty_request = executor_pb2.CreateStructRequest()

    result = stub_under_test.create_struct(request=empty_request)

    grpc_stub.CreateStruct.assert_called_once()
    self.assertEqual(result, response)
async def create_struct(self, elements):
    """Ask the remote service to build a struct out of existing remote values.

    Args:
        elements: A container accepted by `structure.from_container`. Every
            element value is type-checked against `RemoteValue`.

    Returns:
        A `RemoteValue` built from the response's `value_ref`, a
        `computation_types.StructType` assembled from the elements' type
        signatures, and this executor.
    """
    element_protos = []
    element_types = []
    for name, value in structure.iter_elements(
            structure.from_container(elements)):
        py_typecheck.check_type(value, RemoteValue)
        element_protos.append(
            executor_pb2.CreateStructRequest.Element(
                name=name or None, value_ref=value.value_ref))
        # Unnamed elements contribute a bare type; named ones a (name, type)
        # pair.
        if name:
            element_types.append((name, value.type_signature))
        else:
            element_types.append(value.type_signature)

    result_type = computation_types.StructType(element_types)
    response = _request(
        self._stub.CreateStruct,
        executor_pb2.CreateStructRequest(element=element_protos))
    py_typecheck.check_type(response, executor_pb2.CreateStructResponse)
    return RemoteValue(response.value_ref, result_type, self)
def test_executor_service_create_and_select_from_tuple(self):
    """Create two values, pack them into a struct, then select each by index.

    The values 10 and 20 are created through `CreateValue`, combined under the
    names 'a' and 'b' through `CreateStruct`, and read back one at a time
    through `CreateSelection` using indices 0 and 1.
    """
    env = TestEnv(
        executor_test_utils.BasicTestExFactory(
            eager_tf_executor.EagerTFExecutor()))

    # Create each integer remotely, verifying the response as we go.
    named_refs = []
    for name, number in (('a', 10), ('b', 20)):
        serialized, _ = value_serialization.serialize_value(number, tf.int32)
        create_reply = env.stub.CreateValue(
            executor_pb2.CreateValueRequest(
                executor=env.executor_pb, value=serialized))
        self.assertIsInstance(create_reply, executor_pb2.CreateValueResponse)
        created_ref = create_reply.value_ref
        self.assertEqual(env.get_value(created_ref.id), number)
        named_refs.append((name, created_ref))

    struct_reply = env.stub.CreateStruct(
        executor_pb2.CreateStructRequest(
            executor=env.executor_pb,
            element=[
                executor_pb2.CreateStructRequest.Element(
                    name=name, value_ref=ref)
                for name, ref in named_refs
            ]))
    self.assertIsInstance(struct_reply, executor_pb2.CreateStructResponse)
    tuple_ref = struct_reply.value_ref
    self.assertEqual(str(env.get_value(tuple_ref.id)), '<a=10,b=20>')

    # Position 0 holds 10 and position 1 holds 20.
    for index, expected in enumerate((10, 20)):
        selection_reply = env.stub.CreateSelection(
            executor_pb2.CreateSelectionRequest(
                executor=env.executor_pb, source_ref=tuple_ref, index=index))
        self.assertIsInstance(
            selection_reply, executor_pb2.CreateSelectionResponse)
        self.assertEqual(
            env.get_value(selection_reply.value_ref.id), expected)

    del env
def test_executor_service_create_and_select_from_tuple(self):
    """Create two values, pack them into a struct, then select each element.

    The values 10 and 20 are created through `CreateValue`, combined under the
    names 'a' and 'b' through `CreateStruct`, and read back through
    `CreateSelection` first by name and then by index.
    """
    env = TestEnv(
        executor_stacks.ResourceManagingExecutorFactory(
            lambda _: eager_tf_executor.EagerTFExecutor()))

    # Create each integer remotely, verifying the response as we go.
    named_refs = []
    for name, number in (('a', 10), ('b', 20)):
        serialized, _ = executor_serialization.serialize_value(
            number, tf.int32)
        create_reply = env.stub.CreateValue(
            executor_pb2.CreateValueRequest(value=serialized))
        self.assertIsInstance(create_reply, executor_pb2.CreateValueResponse)
        created_ref = create_reply.value_ref
        self.assertEqual(env.get_value(created_ref.id), number)
        named_refs.append((name, created_ref))

    struct_reply = env.stub.CreateStruct(
        executor_pb2.CreateStructRequest(element=[
            executor_pb2.CreateStructRequest.Element(name=name, value_ref=ref)
            for name, ref in named_refs
        ]))
    self.assertIsInstance(struct_reply, executor_pb2.CreateStructResponse)
    tuple_ref = struct_reply.value_ref
    self.assertEqual(str(env.get_value(tuple_ref.id)), '<a=10,b=20>')

    # Each entry pairs the selector keyword argument with the expected value.
    selections = [
        ({'name': 'a'}, 10),
        ({'name': 'b'}, 20),
        ({'index': 0}, 10),
        ({'index': 1}, 20),
    ]
    for selector_kwargs, expected in selections:
        selection_reply = env.stub.CreateSelection(
            executor_pb2.CreateSelectionRequest(
                source_ref=tuple_ref, **selector_kwargs))
        self.assertIsInstance(
            selection_reply, executor_pb2.CreateSelectionResponse)
        self.assertEqual(
            env.get_value(selection_reply.value_ref.id), expected)

    del env
async def create_struct(self, elements):
    """Ask the remote service to build a struct out of existing remote values.

    The request goes through the bidirectional stream when `self._bidi_stream`
    is set, and through a direct `CreateStruct` call otherwise.

    Args:
        elements: A container accepted by `anonymous_tuple.from_container`.
            Every element value is type-checked against `RemoteValue`.

    Returns:
        A `RemoteValue` built from the response's `value_ref`, a
        `computation_types.NamedTupleType` assembled from the elements' type
        signatures, and this executor.
    """
    element_protos = []
    element_types = []
    for name, value in anonymous_tuple.iter_elements(
            anonymous_tuple.from_container(elements)):
        py_typecheck.check_type(value, RemoteValue)
        element_protos.append(
            executor_pb2.CreateStructRequest.Element(
                name=name or None, value_ref=value.value_ref))
        # Unnamed elements contribute a bare type; named ones a (name, type)
        # pair.
        if name:
            element_types.append((name, value.type_signature))
        else:
            element_types.append(value.type_signature)

    result_type = computation_types.NamedTupleType(element_types)
    request = executor_pb2.CreateStructRequest(element=element_protos)

    if self._bidi_stream is not None:
        # Streaming path: wrap the request and unwrap the matching reply field.
        stream_reply = await self._bidi_stream.send_request(
            executor_pb2.ExecuteRequest(create_struct=request))
        response = stream_reply.create_struct
    else:
        response = _request(self._stub.CreateStruct, request)

    py_typecheck.check_type(response, executor_pb2.CreateStructResponse)
    return RemoteValue(response.value_ref, result_type, self)