def test_create_selection_reraises_type_error(self, mock_stub):
    """A `TypeError` raised by the stub's `CreateSelection` must propagate."""
    stub_instance = mock_stub.return_value
    stub_instance.CreateSelection = mock.Mock(side_effect=TypeError)
    event_loop = asyncio.get_event_loop()
    remote = create_remote_executor()
    pair_type = computation_types.StructType([tf.int32, tf.int32])
    selection_source = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), pair_type, remote)
    with self.assertRaises(TypeError):
        event_loop.run_until_complete(
            remote.create_selection(selection_source, 0))
def test_create_call_reraises_type_error(self, mock_stub):
    """A `TypeError` raised by the stub's `CreateCall` must propagate."""
    mock_stub.return_value.CreateCall = mock.Mock(side_effect=TypeError)
    event_loop = asyncio.get_event_loop()
    remote = create_remote_executor()
    fn_type = computation_types.FunctionType(None, tf.int32)
    target = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), fn_type, remote)
    with self.assertRaises(TypeError):
        event_loop.run_until_complete(remote.create_call(target))
def test_create_call_returns_remote_value(self, mock_stub):
    """A streamed `create_call` response is wrapped in a `RemoteValue`."""
    streamed_response = executor_pb2.ExecuteResponse(
        create_call=executor_pb2.CreateCallResponse())
    remote = _setup_mock_streaming_executor(mock_stub, streamed_response)
    event_loop = asyncio.get_event_loop()
    fn_type = computation_types.FunctionType(None, tf.int32)
    fn_value = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), fn_type, remote)
    call_result = event_loop.run_until_complete(
        remote.create_call(fn_value, None))
    self.assertIsInstance(call_result, remote_executor.RemoteValue)
def test_compute_raises_retryable_error_on_grpc_error_unavailable(
        self, mock_stub):
    """An UNAVAILABLE gRPC error from `Compute` surfaces as `RetryableError`."""
    stub_instance = mock_stub.return_value
    stub_instance.Compute = mock.Mock(
        side_effect=_raise_grpc_error_unavailable)
    event_loop = asyncio.get_event_loop()
    remote = create_remote_executor()
    fn_type = computation_types.FunctionType(None, tf.int32)
    remote_value = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), fn_type, remote)
    with self.assertRaises(execution_context.RetryableError):
        event_loop.run_until_complete(remote_value.compute())
def test_create_call_returns_remote_value(self, mock_stub):
    """`create_call` issues one stub RPC and returns a `RemoteValue`."""
    mock_stub.create_call.return_value = executor_pb2.CreateCallResponse()
    remote = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote, mock_stub)
    fn_type = computation_types.FunctionType(None, tf.int32)
    fn_value = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), fn_type, remote)
    outcome = asyncio.run(remote.create_call(fn_value, None))
    mock_stub.create_call.assert_called_once()
    self.assertIsInstance(outcome, remote_executor.RemoteValue)
def test_create_call_reraises_grpc_error(self, mock_stub):
    """A non-retryable gRPC error from `create_call` is re-raised as-is."""
    mock_stub.create_call = mock.Mock(
        side_effect=_raise_non_retryable_grpc_error)
    remote = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote, mock_stub)
    fn_type = computation_types.FunctionType(None, tf.int32)
    target = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), fn_type, remote)
    with self.assertRaises(grpc.RpcError) as raised:
        asyncio.run(remote.create_call(target, None))
    self.assertEqual(raised.exception.code(), grpc.StatusCode.ABORTED)
def test_create_selection_returns_remote_value(self, mock_stub):
    """`create_selection` issues one stub RPC and returns a `RemoteValue`."""
    empty_response = executor_pb2.CreateSelectionResponse()
    mock_stub.create_selection.return_value = empty_response
    remote = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote, mock_stub)
    pair_type = computation_types.StructType([tf.int32, tf.int32])
    selection_source = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), pair_type, remote)
    selected = asyncio.run(remote.create_selection(selection_source, 0))
    mock_stub.create_selection.assert_called_once()
    self.assertIsInstance(selected, remote_executor.RemoteValue)
def test_create_selection_returns_remote_value(self, mock_stub):
    """A streamed `create_selection` response is wrapped in a `RemoteValue`."""
    streamed_response = executor_pb2.ExecuteResponse(
        create_selection=executor_pb2.CreateSelectionResponse())
    remote = _setup_mock_streaming_executor(mock_stub, streamed_response)
    event_loop = asyncio.get_event_loop()
    pair_type = computation_types.StructType([tf.int32, tf.int32])
    selection_source = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), pair_type, remote)
    selected = event_loop.run_until_complete(
        remote.create_selection(selection_source, index=0))
    self.assertIsInstance(selected, remote_executor.RemoteValue)
def test_create_selection_raises_retryable_error_on_grpc_error_unavailable(
        self, mock_stub):
    """An UNAVAILABLE error from `CreateSelection` becomes `RetryableError`."""
    stub_instance = mock_stub.return_value
    stub_instance.CreateSelection = mock.Mock(
        side_effect=_raise_grpc_error_unavailable)
    event_loop = asyncio.get_event_loop()
    remote = create_remote_executor()
    pair_type = computation_types.NamedTupleType([tf.int32, tf.int32])
    selection_source = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), pair_type, remote)
    with self.assertRaises(execution_context.RetryableError):
        event_loop.run_until_complete(
            remote.create_selection(selection_source, index=0))
def test_create_call_reraises_grpc_error(self, mock_stub):
    """A non-retryable gRPC error from `CreateCall` is re-raised as-is."""
    stub_instance = mock_stub.return_value
    stub_instance.CreateCall = mock.Mock(
        side_effect=_raise_non_retryable_grpc_error)
    event_loop = asyncio.get_event_loop()
    remote = create_remote_executor()
    fn_type = computation_types.FunctionType(None, tf.int32)
    target = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), fn_type, remote)
    with self.assertRaises(grpc.RpcError) as raised:
        event_loop.run_until_complete(remote.create_call(target, None))
    self.assertEqual(raised.exception.code(), grpc.StatusCode.ABORTED)
def test_create_call_returns_remote_value(self, mock_stub):
    """`create_call` hits the stub's `CreateCall` once and returns a `RemoteValue`."""
    canned_response = executor_pb2.CreateCallResponse()
    stub_instance = mock_stub.return_value
    stub_instance.CreateCall = mock.Mock(side_effect=[canned_response])
    event_loop = asyncio.get_event_loop()
    remote = create_remote_executor()
    fn_type = computation_types.FunctionType(None, tf.int32)
    fn_value = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), fn_type, remote)
    outcome = event_loop.run_until_complete(
        remote.create_call(fn_value, None))
    stub_instance.CreateCall.assert_called_once()
    self.assertIsInstance(outcome, remote_executor.RemoteValue)
def test_create_selection_reraises_non_retryable_grpc_error(self, mock_stub):
    """A non-retryable gRPC error from `CreateSelection` is re-raised as-is."""
    stub_instance = mock_stub.return_value
    stub_instance.CreateSelection = mock.Mock(
        side_effect=_raise_non_retryable_grpc_error)
    event_loop = asyncio.get_event_loop()
    remote = create_remote_executor()
    pair_type = computation_types.NamedTupleType([tf.int32, tf.int32])
    selection_source = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), pair_type, remote)
    with self.assertRaises(grpc.RpcError) as raised:
        event_loop.run_until_complete(
            remote.create_selection(selection_source, index=0))
    self.assertEqual(raised.exception.code(), grpc.StatusCode.ABORTED)
def test_create_selection_returns_remote_value(self, mock_stub):
    """`create_selection` hits `CreateSelection` once and returns a `RemoteValue`."""
    canned_response = executor_pb2.CreateSelectionResponse()
    stub_instance = mock_stub.return_value
    stub_instance.CreateSelection = mock.Mock(side_effect=[canned_response])
    event_loop = asyncio.get_event_loop()
    remote = create_remote_executor()
    pair_type = computation_types.NamedTupleType([tf.int32, tf.int32])
    selection_source = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), pair_type, remote)
    selected = event_loop.run_until_complete(
        remote.create_selection(selection_source, index=0))
    stub_instance.CreateSelection.assert_called_once()
    self.assertIsInstance(selected, remote_executor.RemoteValue)
def test_compute_reraises_grpc_error(self, grpc_stub):
    """A non-retryable gRPC error from `Compute` propagates through the stub."""
    grpc_stub.return_value.Compute = mock.Mock(
        side_effect=_raise_non_retryable_grpc_error)
    wrapped_stub = create_stub()
    compute_request = executor_pb2.ComputeRequest(
        executor=executor_pb2.ExecutorId(),
        value_ref=executor_pb2.ValueRef())
    with self.assertRaises(grpc.RpcError) as raised:
        wrapped_stub.compute(compute_request)
    self.assertEqual(raised.exception.code(), grpc.StatusCode.ABORTED)
def test_compute_reraises_grpc_error_deadline_exceeded(self, mock_stub):
    """A DEADLINE_EXCEEDED error from `Compute` is re-raised with its code."""
    stub_instance = mock_stub.return_value
    stub_instance.Compute = mock.Mock(
        side_effect=_raise_grpc_error_deadline_exceeded)
    event_loop = asyncio.get_event_loop()
    remote = create_remote_executor()
    fn_type = computation_types.FunctionType(None, tf.int32)
    remote_value = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), fn_type, remote)
    with self.assertRaises(grpc.RpcError) as raised:
        event_loop.run_until_complete(remote_value.compute())
    self.assertEqual(raised.exception.code(),
                     grpc.StatusCode.DEADLINE_EXCEEDED)
def test_compute_returns_result(self, mock_stub):
    """`compute()` unpacks the tensor returned by the stub's `Compute` (here 1)."""
    packed_tensor = any_pb2.Any()
    packed_tensor.Pack(tf.make_tensor_proto(1))
    canned_response = executor_pb2.ComputeResponse(
        value=executor_pb2.Value(tensor=packed_tensor))
    stub_instance = mock_stub.return_value
    stub_instance.Compute = mock.Mock(side_effect=[canned_response])
    event_loop = asyncio.get_event_loop()
    remote = create_remote_executor()
    fn_type = computation_types.FunctionType(None, tf.int32)
    remote_value = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), fn_type, remote)
    computed = event_loop.run_until_complete(remote_value.compute())
    stub_instance.Compute.assert_called_once()
    self.assertEqual(computed, 1)
def test_compute_returns_result(self, mock_stub):
    """`compute()` issues one stub call and returns the packed tensor (1)."""
    packed_tensor = any_pb2.Any()
    packed_tensor.Pack(tf.make_tensor_proto(1))
    mock_stub.compute.return_value = executor_pb2.ComputeResponse(
        value=executor_pb2.Value(tensor=packed_tensor))
    remote = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote, mock_stub)
    remote.set_cardinalities({placements.CLIENTS: 3})
    fn_type = computation_types.FunctionType(None, tf.int32)
    remote_value = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), fn_type, remote)
    computed = asyncio.run(remote_value.compute())
    mock_stub.compute.assert_called_once()
    self.assertEqual(computed, 1)
def test_compute_returns_result(self, mock_executor_grpc_stub):
    """The stub's `compute` returns a response whose value deserializes to 1."""
    packed_tensor = any_pb2.Any()
    packed_tensor.Pack(tf.make_tensor_proto(1))
    canned_response = executor_pb2.ComputeResponse(
        value=executor_pb2.Value(tensor=packed_tensor))
    grpc_instance = mock_executor_grpc_stub.return_value
    grpc_instance.Compute = mock.Mock(side_effect=[canned_response])
    compute_request = executor_pb2.ComputeRequest(
        executor=executor_pb2.ExecutorId(),
        value_ref=executor_pb2.ValueRef())
    wrapped_stub = create_stub()
    reply = wrapped_stub.compute(compute_request)
    grpc_instance.Compute.assert_called_once()
    # Use a distinct name rather than re-binding the proto built above.
    decoded, _ = value_serialization.deserialize_value(reply.value)
    self.assertEqual(decoded, 1)
def CreateValue(
    self,
    request: executor_pb2.CreateValueRequest,
    context: grpc.ServicerContext,
) -> executor_pb2.CreateValueResponse:
    """Deserializes `request.value` and embeds it in the wrapped executor.

    The executor's `create_value` coroutine is submitted to
    `self._event_loop` with `asyncio.run_coroutine_threadsafe`. The returned
    future is recorded in `self._values` under a fresh UUID, and that UUID is
    sent back as the value reference.

    A `ValueError` or `TypeError` is reported on `context` through
    `_set_invalid_arg_err`, and an empty response is returned.
    """
    py_typecheck.check_type(request, executor_pb2.CreateValueRequest)
    try:
        payload, payload_type = (
            executor_service_utils.deserialize_value(request.value))
        ref_id = str(uuid.uuid4())
        create_coro = self._executor.create_value(payload, payload_type)
        scheduled = asyncio.run_coroutine_threadsafe(create_coro,
                                                     self._event_loop)
        with self._lock:
            self._values[ref_id] = scheduled
        return executor_pb2.CreateValueResponse(
            value_ref=executor_pb2.ValueRef(id=ref_id))
    except (ValueError, TypeError) as err:
        _set_invalid_arg_err(context, err)
        return executor_pb2.CreateValueResponse()
def CreateValue(
    self,
    request: executor_pb2.CreateValueRequest,
    context: grpc.ServicerContext,
) -> executor_pb2.CreateValueResponse:
    """Deserializes `request.value` and embeds it in the request's executor.

    Deserialization runs inside a tracing span. The coroutine from
    `self.executor(request, context).create_value` is scheduled through
    `self._run_coro_threadsafe_with_tracing`. The resulting future is stored
    in `self._values` under a fresh UUID, which is returned as the reference.
    Exceptions are left to `self._try_handle_request_context` (defined
    elsewhere; its exact handling is not visible here).
    """
    py_typecheck.check_type(request, executor_pb2.CreateValueRequest)
    with self._try_handle_request_context(
            request, context, executor_pb2.CreateValueResponse):
        with tracing.span('ExecutorService.CreateValue', 'deserialize_value'):
            payload, payload_type = value_serialization.deserialize_value(
                request.value)
        ref_id = str(uuid.uuid4())
        target_executor = self.executor(request, context)
        scheduled = self._run_coro_threadsafe_with_tracing(
            target_executor.create_value(payload, payload_type))
        with self._lock:
            self._values[ref_id] = scheduled
        return executor_pb2.CreateValueResponse(
            value_ref=executor_pb2.ValueRef(id=ref_id))
def CreateValue(
    self,
    request: executor_pb2.CreateValueRequest,
    context: grpc.ServicerContext,
) -> executor_pb2.CreateValueResponse:
    """Deserializes `request.value` and embeds it in the wrapped executor.

    Deserialization is traced. The executor's `create_value` coroutine is
    scheduled via `self._run_coro_threadsafe_with_tracing`, and the resulting
    future is kept in `self._values` under a fresh UUID that is returned as
    the value reference. A `ValueError`/`TypeError` is reported through
    `_set_invalid_arg_err` and yields an empty response.
    """
    py_typecheck.check_type(request, executor_pb2.CreateValueRequest)
    try:
        with tracing.span('ExecutorService.CreateValue', 'deserialize_value'):
            payload, payload_type = executor_service_utils.deserialize_value(
                request.value)
        ref_id = str(uuid.uuid4())
        scheduled = self._run_coro_threadsafe_with_tracing(
            self._executor.create_value(payload, payload_type))
        with self._lock:
            self._values[ref_id] = scheduled
        return executor_pb2.CreateValueResponse(
            value_ref=executor_pb2.ValueRef(id=ref_id))
    except (ValueError, TypeError) as err:
        _set_invalid_arg_err(context, err)
        return executor_pb2.CreateValueResponse()
def CreateTuple(self, request, context):
    """Creates a tuple embedded in the executor.

    Looks up the stored future for every element referenced by the request.
    It then schedules a coroutine on `self._event_loop` that:
      1. awaits those futures;
      2. pairs each result with its (optional) name into an
         `anonymous_tuple.AnonymousTuple`;
      3. passes that tuple to the wrapped executor's `create_tuple`.
    The future for that coroutine is stored in `self._values` under a fresh
    UUID, which is returned as the value reference.

    Args:
      request: An instance of `executor_pb2.CreateTupleRequest`.
      context: An instance of `grpc.ServicerContext`.

    Returns:
      An instance of `executor_pb2.CreateTupleResponse`. It is empty (with
      `INVALID_ARGUMENT` set on `context`) if a `ValueError` or `TypeError`
      is raised.
    """
    py_typecheck.check_type(request, executor_pb2.CreateTupleRequest)
    try:
        with self._lock:
            elem_futures = [
                self._values[e.value_ref.id] for e in request.element
            ]
        elem_names = [
            str(elem.name) if elem.name else None for elem in request.element
        ]

        async def _processing():
            elem_values = await asyncio.gather(
                *[asyncio.wrap_future(v) for v in elem_futures])
            elements = list(zip(elem_names, elem_values))
            anon_tuple = anonymous_tuple.AnonymousTuple(elements)
            return await self._executor.create_tuple(anon_tuple)

        # This (non-async) gRPC handler does not run on the thread driving
        # `self._event_loop`. `loop.create_task` is not thread-safe and does
        # not wake a loop blocked in another thread, so hand the coroutine
        # over with `run_coroutine_threadsafe`. That returns a
        # `concurrent.futures.Future`, which is what later
        # `asyncio.wrap_future` consumers expect.
        result_fut = asyncio.run_coroutine_threadsafe(_processing(),
                                                      self._event_loop)
        result_id = str(uuid.uuid4())
        with self._lock:
            self._values[result_id] = result_fut
        return executor_pb2.CreateTupleResponse(
            value_ref=executor_pb2.ValueRef(id=result_id))
    except (ValueError, TypeError) as err:
        logging.error(traceback.format_exc())
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details(str(err))
        return executor_pb2.CreateTupleResponse()
def CreateValue(self, request, context):
    """Creates a value embedded in the executor.

    Deserializes `request.value` and schedules the wrapped executor's
    `create_value` coroutine on `self._event_loop`. The resulting future is
    stored in `self._values` under a fresh UUID, which is returned as the
    value reference.

    Args:
      request: An instance of `executor_pb2.CreateValueRequest`.
      context: An instance of `grpc.ServicerContext`.

    Returns:
      An instance of `executor_pb2.CreateValueResponse`. It is empty (with
      `INVALID_ARGUMENT` and the error message set on `context`) if a
      `ValueError` or `TypeError` is raised.
    """
    py_typecheck.check_type(request, executor_pb2.CreateValueRequest)
    try:
        value, value_type = (executor_service_utils.deserialize_value(
            request.value))
        value_id = str(uuid.uuid4())
        # `create_value` returns a coroutine. Storing it bare would leave it
        # never scheduled, awaitable at most once, and unusable by consumers
        # that call `asyncio.wrap_future` on entries of `self._values`.
        # Schedule it on the executor's loop and store the resulting
        # `concurrent.futures.Future` instead.
        future_val = asyncio.run_coroutine_threadsafe(
            self._executor.create_value(value, value_type), self._event_loop)
        with self._lock:
            self._values[value_id] = future_val
        return executor_pb2.CreateValueResponse(
            value_ref=executor_pb2.ValueRef(id=value_id))
    except (ValueError, TypeError) as err:
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details(str(err))
        return executor_pb2.CreateValueResponse()
def CreateSelection(self, request, context):
    """Creates a selection embedded in the executor.

    Looks up the stored future for `request.source_ref`. It then schedules a
    coroutine on `self._event_loop` that awaits the source and calls the
    wrapped executor's `create_selection`. The call selects by `name` when
    the `selection` oneof is set to `name`, and by `index` otherwise. The
    resulting future is stored in `self._values` under a fresh UUID.

    Args:
      request: An instance of `executor_pb2.CreateSelectionRequest`.
      context: An instance of `grpc.ServicerContext`.

    Returns:
      An instance of `executor_pb2.CreateSelectionResponse`. It is empty
      (with `INVALID_ARGUMENT` set on `context`) if a `ValueError` or
      `TypeError` is raised.
    """
    py_typecheck.check_type(request, executor_pb2.CreateSelectionRequest)
    try:
        with self._lock:
            source_fut = self._values[request.source_ref.id]

        async def _processing():
            source = await asyncio.wrap_future(source_fut)
            which_selection = request.WhichOneof('selection')
            if which_selection == 'name':
                coro = self._executor.create_selection(
                    source, name=request.name)
            else:
                coro = self._executor.create_selection(
                    source, index=request.index)
            return await coro

        # This (non-async) gRPC handler does not run on the thread driving
        # `self._event_loop`. `loop.create_task` is not thread-safe from
        # here, so use `run_coroutine_threadsafe`. It returns a
        # `concurrent.futures.Future`, which is what later
        # `asyncio.wrap_future` consumers expect.
        result_fut = asyncio.run_coroutine_threadsafe(_processing(),
                                                      self._event_loop)
        result_id = str(uuid.uuid4())
        with self._lock:
            self._values[result_id] = result_fut
        return executor_pb2.CreateSelectionResponse(
            value_ref=executor_pb2.ValueRef(id=result_id))
    except (ValueError, TypeError) as err:
        logging.error(traceback.format_exc())
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details(str(err))
        return executor_pb2.CreateSelectionResponse()
def CreateCall(self, request, context):
    """Creates a call embedded in the executor.

    Looks up the stored future for the function and, when
    `request.argument_ref.id` is non-empty, for the argument. It then
    schedules a coroutine on `self._event_loop` that awaits them and calls
    the wrapped executor's `create_call`. The argument is `None` when no
    argument id was given. The resulting future is stored in `self._values`
    under a fresh UUID.

    Args:
      request: An instance of `executor_pb2.CreateCallRequest`.
      context: An instance of `grpc.ServicerContext`.

    Returns:
      An instance of `executor_pb2.CreateCallResponse`. It is empty (with
      `INVALID_ARGUMENT` set on `context`) if a `ValueError` or `TypeError`
      is raised.
    """
    py_typecheck.check_type(request, executor_pb2.CreateCallRequest)
    try:
        function_id = str(request.function_ref.id)
        argument_id = str(request.argument_ref.id)
        with self._lock:
            function_val = self._values[function_id]
            argument_val = self._values[argument_id] if argument_id else None

        async def _processing():
            function = await asyncio.wrap_future(function_val)
            argument = (await asyncio.wrap_future(argument_val)
                        if argument_val is not None else None)
            return await self._executor.create_call(function, argument)

        # This (non-async) gRPC handler does not run on the thread driving
        # `self._event_loop`. `loop.create_task` is not thread-safe from
        # here, so use `run_coroutine_threadsafe`. It returns a
        # `concurrent.futures.Future`, which is what later
        # `asyncio.wrap_future` consumers expect.
        result_fut = asyncio.run_coroutine_threadsafe(_processing(),
                                                      self._event_loop)
        result_id = str(uuid.uuid4())
        with self._lock:
            self._values[result_id] = result_fut
        return executor_pb2.CreateCallResponse(
            value_ref=executor_pb2.ValueRef(id=result_id))
    except (ValueError, TypeError) as err:
        logging.error(traceback.format_exc())
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details(str(err))
        return executor_pb2.CreateCallResponse()
def CreateSelection(
    self,
    request: executor_pb2.CreateSelectionRequest,
    context: grpc.ServicerContext,
) -> executor_pb2.CreateSelectionResponse:
    """Selects element `request.index` of a stored value in the executor.

    The source future is looked up under `self._lock`. A coroutine that
    awaits it and calls `self.executor(request, context).create_selection`
    is scheduled via `self._run_coro_threadsafe_with_tracing`. The resulting
    future is stored under a fresh UUID, which is returned as the reference.
    Exceptions are left to `self._try_handle_request_context` (defined
    elsewhere; its exact handling is not visible here).
    """
    py_typecheck.check_type(request, executor_pb2.CreateSelectionRequest)
    with self._try_handle_request_context(
            request, context, executor_pb2.CreateSelectionResponse):
        with self._lock:
            pending_source = self._values[request.source_ref.id]

        async def _select():
            resolved = await asyncio.wrap_future(pending_source)
            target_executor = self.executor(request, context)
            return await target_executor.create_selection(
                resolved, request.index)

        scheduled = self._run_coro_threadsafe_with_tracing(_select())
        selection_id = str(uuid.uuid4())
        with self._lock:
            self._values[selection_id] = scheduled
        return executor_pb2.CreateSelectionResponse(
            value_ref=executor_pb2.ValueRef(id=selection_id))