Example #1
0
  def test_create_selection_reraises_type_error(self, mock_stub):
    """A TypeError raised by the CreateSelection stub reaches the caller."""
    stub_instance = mock_stub.return_value
    stub_instance.CreateSelection = mock.Mock(side_effect=TypeError)
    event_loop = asyncio.get_event_loop()
    remote = create_remote_executor()
    pair_type = computation_types.StructType([tf.int32, tf.int32])
    source_value = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), pair_type, remote)

    with self.assertRaises(TypeError):
      pending = remote.create_selection(source_value, 0)
      event_loop.run_until_complete(pending)
Example #2
0
  def test_create_call_reraises_type_error(self, mock_stub):
    """A TypeError raised by the CreateCall stub reaches the caller."""
    stub_instance = mock_stub.return_value
    stub_instance.CreateCall = mock.Mock(side_effect=TypeError)
    event_loop = asyncio.get_event_loop()
    remote = create_remote_executor()
    fn_type = computation_types.FunctionType(None, tf.int32)
    remote_comp = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), fn_type, remote)

    with self.assertRaises(TypeError):
      pending = remote.create_call(remote_comp)
      event_loop.run_until_complete(pending)
    def test_create_call_returns_remote_value(self, mock_stub):
        """create_call through the streaming stub yields a RemoteValue."""
        streamed_response = executor_pb2.ExecuteResponse(
            create_call=executor_pb2.CreateCallResponse())
        remote = _setup_mock_streaming_executor(mock_stub, streamed_response)
        event_loop = asyncio.get_event_loop()
        fn_type = computation_types.FunctionType(None, tf.int32)
        target_fn = remote_executor.RemoteValue(
            executor_pb2.ValueRef(), fn_type, remote)

        call_result = event_loop.run_until_complete(
            remote.create_call(target_fn, None))

        self.assertIsInstance(call_result, remote_executor.RemoteValue)
Example #4
0
    def test_compute_raises_retryable_error_on_grpc_error_unavailable(
            self, mock_stub):
        """An UNAVAILABLE gRPC error from Compute surfaces as RetryableError."""
        stub_instance = mock_stub.return_value
        stub_instance.Compute = mock.Mock(
            side_effect=_raise_grpc_error_unavailable)
        event_loop = asyncio.get_event_loop()
        remote = create_remote_executor()
        fn_type = computation_types.FunctionType(None, tf.int32)
        remote_comp = remote_executor.RemoteValue(
            executor_pb2.ValueRef(), fn_type, remote)

        with self.assertRaises(execution_context.RetryableError):
            compute_coro = remote_comp.compute()
            event_loop.run_until_complete(compute_coro)
    def test_create_call_returns_remote_value(self, mock_stub):
        """create_call hits the stub exactly once and returns a RemoteValue."""
        canned_response = executor_pb2.CreateCallResponse()
        mock_stub.create_call.return_value = canned_response
        remote = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(remote, mock_stub)
        fn_type = computation_types.FunctionType(None, tf.int32)
        target_fn = remote_executor.RemoteValue(
            executor_pb2.ValueRef(), fn_type, remote)

        call_result = asyncio.run(remote.create_call(target_fn, None))

        mock_stub.create_call.assert_called_once()
        self.assertIsInstance(call_result, remote_executor.RemoteValue)
    def test_create_call_reraises_grpc_error(self, mock_stub):
        """A non-retryable gRPC error from create_call propagates as RpcError."""
        mock_stub.create_call = mock.Mock(
            side_effect=_raise_non_retryable_grpc_error)
        remote = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(remote, mock_stub)
        fn_type = computation_types.FunctionType(None, tf.int32)
        remote_comp = remote_executor.RemoteValue(
            executor_pb2.ValueRef(), fn_type, remote)

        with self.assertRaises(grpc.RpcError) as raised:
            asyncio.run(remote.create_call(remote_comp, None))

        status = raised.exception.code()
        self.assertEqual(status, grpc.StatusCode.ABORTED)
    def test_create_selection_returns_remote_value(self, mock_stub):
        """create_selection hits the stub once and returns a RemoteValue."""
        mock_stub.create_selection.return_value = (
            executor_pb2.CreateSelectionResponse())
        remote = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(remote, mock_stub)
        pair_type = computation_types.StructType([tf.int32, tf.int32])
        source_value = remote_executor.RemoteValue(
            executor_pb2.ValueRef(), pair_type, remote)

        selection = asyncio.run(remote.create_selection(source_value, 0))

        mock_stub.create_selection.assert_called_once()
        self.assertIsInstance(selection, remote_executor.RemoteValue)
    def test_create_selection_returns_remote_value(self, mock_stub):
        """create_selection through the streaming stub yields a RemoteValue."""
        streamed_response = executor_pb2.ExecuteResponse(
            create_selection=executor_pb2.CreateSelectionResponse())
        remote = _setup_mock_streaming_executor(mock_stub, streamed_response)
        event_loop = asyncio.get_event_loop()
        pair_type = computation_types.StructType([tf.int32, tf.int32])
        source_value = remote_executor.RemoteValue(
            executor_pb2.ValueRef(), pair_type, remote)

        selection_coro = remote.create_selection(source_value, index=0)
        selection = event_loop.run_until_complete(selection_coro)

        self.assertIsInstance(selection, remote_executor.RemoteValue)
    def test_create_selection_raises_retryable_error_on_grpc_error_unavailable(
            self, mock_stub):
        """An UNAVAILABLE error from CreateSelection becomes RetryableError."""
        stub_instance = mock_stub.return_value
        stub_instance.CreateSelection = mock.Mock(
            side_effect=_raise_grpc_error_unavailable)
        event_loop = asyncio.get_event_loop()
        remote = create_remote_executor()
        pair_type = computation_types.NamedTupleType([tf.int32, tf.int32])
        source_value = remote_executor.RemoteValue(
            executor_pb2.ValueRef(), pair_type, remote)

        with self.assertRaises(execution_context.RetryableError):
            event_loop.run_until_complete(
                remote.create_selection(source_value, index=0))
  def test_create_call_reraises_grpc_error(self, mock_stub):
    """A non-retryable gRPC error from CreateCall propagates as RpcError."""
    stub_instance = mock_stub.return_value
    stub_instance.CreateCall = mock.Mock(
        side_effect=_raise_non_retryable_grpc_error)
    event_loop = asyncio.get_event_loop()
    remote = create_remote_executor()
    fn_type = computation_types.FunctionType(None, tf.int32)
    remote_comp = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), fn_type, remote)

    with self.assertRaises(grpc.RpcError) as raised:
      event_loop.run_until_complete(remote.create_call(remote_comp, None))

    status = raised.exception.code()
    self.assertEqual(status, grpc.StatusCode.ABORTED)
Example #11
0
    def test_create_call_returns_remote_value(self, mock_stub):
        """CreateCall is invoked once and create_call yields a RemoteValue."""
        canned_response = executor_pb2.CreateCallResponse()
        stub_instance = mock_stub.return_value
        stub_instance.CreateCall = mock.Mock(side_effect=[canned_response])
        event_loop = asyncio.get_event_loop()
        remote = create_remote_executor()
        fn_type = computation_types.FunctionType(None, tf.int32)
        target_fn = remote_executor.RemoteValue(
            executor_pb2.ValueRef(), fn_type, remote)

        call_result = event_loop.run_until_complete(
            remote.create_call(target_fn, None))

        stub_instance.CreateCall.assert_called_once()
        self.assertIsInstance(call_result, remote_executor.RemoteValue)
  def test_create_selection_reraises_non_retryable_grpc_error(self, mock_stub):
    """A non-retryable CreateSelection gRPC error propagates as RpcError."""
    stub_instance = mock_stub.return_value
    stub_instance.CreateSelection = mock.Mock(
        side_effect=_raise_non_retryable_grpc_error)
    event_loop = asyncio.get_event_loop()
    remote = create_remote_executor()
    pair_type = computation_types.NamedTupleType([tf.int32, tf.int32])
    source_value = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), pair_type, remote)

    with self.assertRaises(grpc.RpcError) as raised:
      selection_coro = remote.create_selection(source_value, index=0)
      event_loop.run_until_complete(selection_coro)

    status = raised.exception.code()
    self.assertEqual(status, grpc.StatusCode.ABORTED)
  def test_create_selection_returns_remote_value(self, mock_stub):
    """CreateSelection is invoked once and yields a RemoteValue."""
    canned_response = executor_pb2.CreateSelectionResponse()
    stub_instance = mock_stub.return_value
    stub_instance.CreateSelection = mock.Mock(side_effect=[canned_response])
    event_loop = asyncio.get_event_loop()
    remote = create_remote_executor()
    pair_type = computation_types.NamedTupleType([tf.int32, tf.int32])
    source_value = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), pair_type, remote)

    selection_coro = remote.create_selection(source_value, index=0)
    selection = event_loop.run_until_complete(selection_coro)

    stub_instance.CreateSelection.assert_called_once()
    self.assertIsInstance(selection, remote_executor.RemoteValue)
    def test_compute_reraises_grpc_error(self, grpc_stub):
        """The stub wrapper lets a non-retryable Compute RpcError escape."""
        grpc_instance = grpc_stub.return_value
        grpc_instance.Compute = mock.Mock(
            side_effect=_raise_non_retryable_grpc_error)

        wrapped_stub = create_stub()
        compute_request = executor_pb2.ComputeRequest(
            executor=executor_pb2.ExecutorId(),
            value_ref=executor_pb2.ValueRef())

        with self.assertRaises(grpc.RpcError) as raised:
            wrapped_stub.compute(compute_request)

        status = raised.exception.code()
        self.assertEqual(status, grpc.StatusCode.ABORTED)
Example #15
0
    def test_compute_reraises_grpc_error_deadline_exceeded(self, mock_stub):
        """A DEADLINE_EXCEEDED Compute error propagates as grpc.RpcError."""
        stub_instance = mock_stub.return_value
        stub_instance.Compute = mock.Mock(
            side_effect=_raise_grpc_error_deadline_exceeded)
        event_loop = asyncio.get_event_loop()
        remote = create_remote_executor()
        fn_type = computation_types.FunctionType(None, tf.int32)
        remote_comp = remote_executor.RemoteValue(
            executor_pb2.ValueRef(), fn_type, remote)

        with self.assertRaises(grpc.RpcError) as raised:
            event_loop.run_until_complete(remote_comp.compute())

        expected_code = grpc.StatusCode.DEADLINE_EXCEEDED
        self.assertEqual(raised.exception.code(), expected_code)
Example #16
0
    def test_compute_returns_result(self, mock_stub):
        """compute() returns the value carried by the stub's Compute reply."""
        one_proto = tf.make_tensor_proto(1)
        packed_tensor = any_pb2.Any()
        packed_tensor.Pack(one_proto)
        compute_response = executor_pb2.ComputeResponse(
            value=executor_pb2.Value(tensor=packed_tensor))
        stub_instance = mock_stub.return_value
        stub_instance.Compute = mock.Mock(side_effect=[compute_response])
        event_loop = asyncio.get_event_loop()
        remote = create_remote_executor()
        fn_type = computation_types.FunctionType(None, tf.int32)
        remote_comp = remote_executor.RemoteValue(
            executor_pb2.ValueRef(), fn_type, remote)

        computed = event_loop.run_until_complete(remote_comp.compute())

        stub_instance.Compute.assert_called_once()
        self.assertEqual(computed, 1)
    def test_compute_returns_result(self, mock_stub):
        """compute() calls the stub once and returns the packed tensor value."""
        one_proto = tf.make_tensor_proto(1)
        packed_tensor = any_pb2.Any()
        packed_tensor.Pack(one_proto)
        wire_value = executor_pb2.Value(tensor=packed_tensor)
        mock_stub.compute.return_value = executor_pb2.ComputeResponse(
            value=wire_value)
        remote = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(remote, mock_stub)
        remote.set_cardinalities({placements.CLIENTS: 3})
        fn_type = computation_types.FunctionType(None, tf.int32)
        remote_comp = remote_executor.RemoteValue(
            executor_pb2.ValueRef(), fn_type, remote)

        computed = asyncio.run(remote_comp.compute())

        mock_stub.compute.assert_called_once()
        self.assertEqual(computed, 1)
    def test_compute_returns_result(self, mock_executor_grpc_stub):
        """The stub wrapper's compute() reply deserializes back to 1."""
        one_proto = tf.make_tensor_proto(1)
        packed_tensor = any_pb2.Any()
        packed_tensor.Pack(one_proto)
        compute_response = executor_pb2.ComputeResponse(
            value=executor_pb2.Value(tensor=packed_tensor))
        grpc_instance = mock_executor_grpc_stub.return_value
        grpc_instance.Compute = mock.Mock(side_effect=[compute_response])

        compute_request = executor_pb2.ComputeRequest(
            executor=executor_pb2.ExecutorId(),
            value_ref=executor_pb2.ValueRef())

        wrapped_stub = create_stub()
        reply = wrapped_stub.compute(compute_request)

        grpc_instance.Compute.assert_called_once()

        decoded, _ = value_serialization.deserialize_value(reply.value)
        self.assertEqual(decoded, 1)
Example #19
0
 def CreateValue(
     self,
     request: executor_pb2.CreateValueRequest,
     context: grpc.ServicerContext,
 ) -> executor_pb2.CreateValueResponse:
   """Deserializes `request.value` and embeds it in the wrapped executor.

   The `create_value` coroutine is scheduled on `self._event_loop`, and the
   resulting future is registered in `self._values` under a fresh UUID.

   Args:
     request: The `executor_pb2.CreateValueRequest` carrying the value.
     context: The `grpc.ServicerContext` of the call.

   Returns:
     A `CreateValueResponse` referencing the new value. If a `ValueError` or
     `TypeError` is raised while handling the request, `_set_invalid_arg_err`
     is applied to `context` and an empty response is returned instead.
   """
   py_typecheck.check_type(request, executor_pb2.CreateValueRequest)
   try:
     deserialized, deserialized_type = (
         executor_service_utils.deserialize_value(request.value))
     new_id = str(uuid.uuid4())
     create_coro = self._executor.create_value(deserialized, deserialized_type)
     pending = asyncio.run_coroutine_threadsafe(create_coro, self._event_loop)
     with self._lock:
       self._values[new_id] = pending
     return executor_pb2.CreateValueResponse(
         value_ref=executor_pb2.ValueRef(id=new_id))
   except (ValueError, TypeError) as err:
     _set_invalid_arg_err(context, err)
     return executor_pb2.CreateValueResponse()
 def CreateValue(
     self,
     request: executor_pb2.CreateValueRequest,
     context: grpc.ServicerContext,
 ) -> executor_pb2.CreateValueResponse:
     """Deserializes `request.value` and embeds it in the selected executor.

     Deserialization runs inside a tracing span. The `create_value` coroutine
     is handed to `self._run_coro_threadsafe_with_tracing`, and whatever that
     returns is registered in `self._values` under a fresh UUID. Exceptions
     are handled by `self._try_handle_request_context` (see that helper for
     the exact status/response it produces).

     Args:
       request: The `executor_pb2.CreateValueRequest` carrying the value.
       context: The `grpc.ServicerContext` of the call.

     Returns:
       A `CreateValueResponse` referencing the newly registered value.
     """
     py_typecheck.check_type(request, executor_pb2.CreateValueRequest)
     with self._try_handle_request_context(
             request, context, executor_pb2.CreateValueResponse):
         with tracing.span('ExecutorService.CreateValue', 'deserialize_value'):
             deserialized, deserialized_type = (
                 value_serialization.deserialize_value(request.value))
         new_id = str(uuid.uuid4())
         target_executor = self.executor(request, context)
         create_coro = target_executor.create_value(deserialized,
                                                    deserialized_type)
         pending = self._run_coro_threadsafe_with_tracing(create_coro)
         with self._lock:
             self._values[new_id] = pending
         return executor_pb2.CreateValueResponse(
             value_ref=executor_pb2.ValueRef(id=new_id))
 def CreateValue(
     self,
     request: executor_pb2.CreateValueRequest,
     context: grpc.ServicerContext,
 ) -> executor_pb2.CreateValueResponse:
   """Deserializes `request.value` and embeds it in the wrapped executor.

   Deserialization runs inside a tracing span. The `create_value` coroutine
   is handed to `self._run_coro_threadsafe_with_tracing`, and its result is
   registered in `self._values` under a fresh UUID.

   Args:
     request: The `executor_pb2.CreateValueRequest` carrying the value.
     context: The `grpc.ServicerContext` of the call.

   Returns:
     A `CreateValueResponse` referencing the new value. If a `ValueError` or
     `TypeError` is raised, `_set_invalid_arg_err` is applied to `context`
     and an empty response is returned instead.
   """
   py_typecheck.check_type(request, executor_pb2.CreateValueRequest)
   try:
     with tracing.span('ExecutorService.CreateValue', 'deserialize_value'):
       deserialized, deserialized_type = (
           executor_service_utils.deserialize_value(request.value))
     new_id = str(uuid.uuid4())
     pending = self._run_coro_threadsafe_with_tracing(
         self._executor.create_value(deserialized, deserialized_type))
     with self._lock:
       self._values[new_id] = pending
     return executor_pb2.CreateValueResponse(
         value_ref=executor_pb2.ValueRef(id=new_id))
   except (ValueError, TypeError) as err:
     _set_invalid_arg_err(context, err)
     return executor_pb2.CreateValueResponse()
Example #22
0
    def CreateTuple(self, request, context):
        """Creates a tuple embedded in the executor.

        The element futures are looked up in `self._values` under the lock.
        Tuple construction is scheduled on `self._event_loop`, and its future
        is registered under a fresh UUID.

        Args:
          request: An instance of `executor_pb2.CreateTupleRequest`.
          context: An instance of `grpc.ServicerContext`.

        Returns:
          An instance of `executor_pb2.CreateTupleResponse` referencing the
          pending tuple. If a `ValueError` or `TypeError` is raised while
          scheduling it, the traceback is logged, `context` is marked
          `INVALID_ARGUMENT` with the error message as details, and an empty
          response is returned.
        """
        py_typecheck.check_type(request, executor_pb2.CreateTupleRequest)
        try:
            with self._lock:
                elem_futures = [
                    self._values[e.value_ref.id] for e in request.element
                ]
            # Empty names become None so the tuple element is unnamed.
            elem_names = [
                str(elem.name) if elem.name else None
                for elem in request.element
            ]

            async def _processing():
                elem_values = await asyncio.gather(
                    *[asyncio.wrap_future(v) for v in elem_futures])
                elements = list(zip(elem_names, elem_values))
                anon_tuple = anonymous_tuple.AnonymousTuple(elements)
                return await self._executor.create_tuple(anon_tuple)

            # NOTE(review): servicer methods are presumably invoked on gRPC
            # worker threads rather than the thread running
            # `self._event_loop`, and `loop.create_task` is not thread-safe.
            # `run_coroutine_threadsafe` is thread-safe and returns a
            # `concurrent.futures.Future`, which is the type
            # `asyncio.wrap_future` is designed to consume.
            result_fut = asyncio.run_coroutine_threadsafe(
                _processing(), self._event_loop)
            result_id = str(uuid.uuid4())
            with self._lock:
                self._values[result_id] = result_fut
            return executor_pb2.CreateTupleResponse(
                value_ref=executor_pb2.ValueRef(id=result_id))
        except (ValueError, TypeError) as err:
            logging.error(traceback.format_exc())
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(str(err))
            return executor_pb2.CreateTupleResponse()
    def CreateValue(self, request, context):
        """Creates a value embedded in the executor.

        The `create_value` coroutine is scheduled on `self._event_loop`, and
        the resulting future is registered in `self._values` under a fresh
        UUID.

        Args:
          request: An instance of `executor_pb2.CreateValueRequest`.
          context: An instance of `grpc.ServicerContext`.

        Returns:
          An instance of `executor_pb2.CreateValueResponse` referencing the
          new value. If a `ValueError` or `TypeError` is raised, `context` is
          marked `INVALID_ARGUMENT` with the error message as details and an
          empty response is returned.
        """
        py_typecheck.check_type(request, executor_pb2.CreateValueRequest)
        try:
            value, value_type = (executor_service_utils.deserialize_value(
                request.value))
            value_id = str(uuid.uuid4())
            # `self._executor.create_value(...)` returns a coroutine. Storing
            # that bare coroutine left it un-awaited, and it also broke the
            # `asyncio.wrap_future` calls that consume `self._values`, since
            # wrap_future rejects non-future input. Schedule the coroutine on
            # the service loop and store the `concurrent.futures.Future`.
            future_val = asyncio.run_coroutine_threadsafe(
                self._executor.create_value(value, value_type),
                self._event_loop)
            with self._lock:
                self._values[value_id] = future_val
            return executor_pb2.CreateValueResponse(
                value_ref=executor_pb2.ValueRef(id=value_id))
        except (ValueError, TypeError) as err:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(str(err))
            return executor_pb2.CreateValueResponse()
Example #24
0
    def CreateSelection(self, request, context):
        """Creates a selection embedded in the executor.

        The source future is looked up in `self._values` under the lock. The
        selection is made by name when the request's `selection` oneof is set
        to `name`, and by index otherwise. It is scheduled on
        `self._event_loop` and registered under a fresh UUID.

        Args:
          request: An instance of `executor_pb2.CreateSelectionRequest`.
          context: An instance of `grpc.ServicerContext`.

        Returns:
          An instance of `executor_pb2.CreateSelectionResponse` referencing
          the pending selection. If a `ValueError` or `TypeError` is raised
          while scheduling it, the traceback is logged, `context` is marked
          `INVALID_ARGUMENT` with the error message as details, and an empty
          response is returned.
        """
        py_typecheck.check_type(request, executor_pb2.CreateSelectionRequest)
        try:
            with self._lock:
                source_fut = self._values[request.source_ref.id]

            async def _processing():
                source = await asyncio.wrap_future(source_fut)
                which_selection = request.WhichOneof('selection')
                if which_selection == 'name':
                    coro = self._executor.create_selection(source,
                                                           name=request.name)
                else:
                    coro = self._executor.create_selection(source,
                                                           index=request.index)
                return await coro

            # NOTE(review): `loop.create_task` is not thread-safe, and this
            # handler presumably runs on a gRPC worker thread rather than the
            # loop's thread. `run_coroutine_threadsafe` is thread-safe and
            # returns a `concurrent.futures.Future`, the type that
            # `asyncio.wrap_future` expects.
            result_fut = asyncio.run_coroutine_threadsafe(
                _processing(), self._event_loop)
            result_id = str(uuid.uuid4())
            with self._lock:
                self._values[result_id] = result_fut
            return executor_pb2.CreateSelectionResponse(
                value_ref=executor_pb2.ValueRef(id=result_id))
        except (ValueError, TypeError) as err:
            logging.error(traceback.format_exc())
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(str(err))
            return executor_pb2.CreateSelectionResponse()
Example #25
0
    def CreateCall(self, request, context):
        """Creates a call embedded in the executor.

        The function future, and the argument future when `argument_ref.id`
        is non-empty, are looked up in `self._values` under the lock. The call
        is scheduled on `self._event_loop` and registered under a fresh UUID.

        Args:
          request: An instance of `executor_pb2.CreateCallRequest`.
          context: An instance of `grpc.ServicerContext`.

        Returns:
          An instance of `executor_pb2.CreateCallResponse` referencing the
          pending call. If a `ValueError` or `TypeError` is raised while
          scheduling it, the traceback is logged, `context` is marked
          `INVALID_ARGUMENT` with the error message as details, and an empty
          response is returned.
        """
        py_typecheck.check_type(request, executor_pb2.CreateCallRequest)
        try:
            function_id = str(request.function_ref.id)
            argument_id = str(request.argument_ref.id)
            with self._lock:
                function_val = self._values[function_id]
                # An empty argument id means the call takes no argument.
                argument_val = self._values[
                    argument_id] if argument_id else None

            async def _processing():
                function = await asyncio.wrap_future(function_val)
                argument = await asyncio.wrap_future(
                    argument_val) if argument_val is not None else None
                return await self._executor.create_call(function, argument)

            # NOTE(review): `loop.create_task` is not thread-safe, and this
            # handler presumably runs on a gRPC worker thread rather than the
            # loop's thread. `run_coroutine_threadsafe` is thread-safe and
            # returns a `concurrent.futures.Future`, the type that
            # `asyncio.wrap_future` expects.
            result_fut = asyncio.run_coroutine_threadsafe(
                _processing(), self._event_loop)
            result_id = str(uuid.uuid4())
            with self._lock:
                self._values[result_id] = result_fut
            return executor_pb2.CreateCallResponse(
                value_ref=executor_pb2.ValueRef(id=result_id))
        except (ValueError, TypeError) as err:
            logging.error(traceback.format_exc())
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(str(err))
            return executor_pb2.CreateCallResponse()
    def CreateSelection(
        self,
        request: executor_pb2.CreateSelectionRequest,
        context: grpc.ServicerContext,
    ) -> executor_pb2.CreateSelectionResponse:
        """Selects element `request.index` from a previously created value.

        The source future is fetched from `self._values` under the lock. The
        selection coroutine is handed to
        `self._run_coro_threadsafe_with_tracing`, and its result is registered
        under a fresh UUID. Exceptions are handled by
        `self._try_handle_request_context` (see that helper for the exact
        status/response it produces).

        Args:
          request: The `executor_pb2.CreateSelectionRequest` naming the source
            value reference and the index to select.
          context: The `grpc.ServicerContext` of the call.

        Returns:
          A `CreateSelectionResponse` referencing the pending selection.
        """
        py_typecheck.check_type(request, executor_pb2.CreateSelectionRequest)
        with self._try_handle_request_context(
                request, context, executor_pb2.CreateSelectionResponse):
            with self._lock:
                pending_source = self._values[request.source_ref.id]

            async def _select_from_source():
                resolved_source = await asyncio.wrap_future(pending_source)
                target_executor = self.executor(request, context)
                return await target_executor.create_selection(
                    resolved_source, request.index)

            pending_selection = self._run_coro_threadsafe_with_tracing(
                _select_from_source())
            new_id = str(uuid.uuid4())
            with self._lock:
                self._values[new_id] = pending_selection
            return executor_pb2.CreateSelectionResponse(
                value_ref=executor_pb2.ValueRef(id=new_id))