예제 #1
0
    def CreateTuple(self, request, context):
        """Embeds a new tuple, built from already-stored values, in the executor.

        Each entry of `request.element` refers by id to a value held by this
        service. The referenced futures are waited on via `.result()` (outside
        the lock), paired with their optional names, and wrapped in an
        `AnonymousTuple`. The wrapped executor's `create_tuple` coroutine is then
        scheduled on `self._event_loop`, and its future is registered under a
        freshly generated UUID.

        Args:
          request: An instance of `executor_pb2.CreateTupleRequest`.
          context: An instance of `grpc.ServicerContext`; its status becomes
            `INVALID_ARGUMENT` when a `ValueError` or `TypeError` is raised.

        Returns:
          An instance of `executor_pb2.CreateTupleResponse` referencing the new
          value, or an empty response when the request was rejected.
        """
        py_typecheck.check_type(request, executor_pb2.CreateTupleRequest)
        try:
            # The lock only guards the shared value table; waiting on the
            # futures happens after it is released.
            with self._lock:
                pending = [self._values[e.value_ref.id] for e in request.element]
            named_values = [
                (str(member.name) if member.name else None, fut.result())
                for member, fut in zip(request.element, pending)
            ]
            scheduled = asyncio.run_coroutine_threadsafe(
                self._executor.create_tuple(
                    anonymous_tuple.AnonymousTuple(named_values)),
                self._event_loop)
            new_id = str(uuid.uuid4())
            with self._lock:
                self._values[new_id] = scheduled
            ref = executor_pb2.ValueRef(id=new_id)
            return executor_pb2.CreateTupleResponse(value_ref=ref)
        except (ValueError, TypeError) as err:
            logging.error(traceback.format_exc())
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(str(err))
            return executor_pb2.CreateTupleResponse()
예제 #2
0
  def CreateTuple(
      self,
      request: executor_pb2.CreateTupleRequest,
      context: grpc.ServicerContext,
  ) -> executor_pb2.CreateTupleResponse:
    """Creates a tuple embedded in the executor.

    The tuple's elements are values previously registered with this service
    and referenced by id in `request.element`. The handler does not wait for
    them. Instead, a coroutine that awaits the element futures, assembles the
    `AnonymousTuple`, and calls the wrapped executor's `create_tuple` is
    scheduled on `self._event_loop`. The future for that coroutine is stored
    under a fresh UUID, and that UUID is returned to the client.

    Args:
      request: The `executor_pb2.CreateTupleRequest` naming the elements.
      context: The `grpc.ServicerContext` for this call; its status is set to
        `INVALID_ARGUMENT` if a `ValueError` or `TypeError` is raised while
        handling the request.

    Returns:
      A `CreateTupleResponse` holding a `ValueRef` to the new value, or an
      empty response if the request was rejected.
    """
    py_typecheck.check_type(request, executor_pb2.CreateTupleRequest)
    try:
      # Lock only for the lookups in the shared value table. An unknown id
      # surfaces as a lookup error (presumably KeyError) that is not
      # translated to INVALID_ARGUMENT here.
      with self._lock:
        elem_futures = [self._values[e.value_ref.id] for e in request.element]
      elem_names = [
          str(elem.name) if elem.name else None for elem in request.element
      ]

      async def _processing():
        # Runs on the event loop, so the element futures are awaited there
        # rather than blocking this handler thread.
        elem_values = await asyncio.gather(
            *[asyncio.wrap_future(v) for v in elem_futures])
        elements = list(zip(elem_names, elem_values))
        anon_tuple = anonymous_tuple.AnonymousTuple(elements)
        return await self._executor.create_tuple(anon_tuple)

      # This handler runs on a gRPC server thread, not on the event loop's
      # thread. `loop.create_task` is not thread-safe (and does not wake an
      # idle loop), and it would store an `asyncio.Task`, whose `.result()`
      # raises until done. `run_coroutine_threadsafe` schedules safely and
      # yields a blocking `concurrent.futures.Future`.
      # NOTE(review): confirm other handlers store the same future type.
      result_fut = asyncio.run_coroutine_threadsafe(_processing(),
                                                    self._event_loop)
      result_id = str(uuid.uuid4())
      with self._lock:
        self._values[result_id] = result_fut
      return executor_pb2.CreateTupleResponse(
          value_ref=executor_pb2.ValueRef(id=result_id))
    except (ValueError, TypeError) as err:
      # Only synchronous failures land here. Errors raised inside
      # `_processing` are reported through `result_fut` when it is consumed.
      _set_invalid_arg_err(context, err)
      return executor_pb2.CreateTupleResponse()
예제 #3
0
    def CreateTuple(self, request, context):
        """Creates a tuple embedded in the executor.

        The tuple's elements are values previously registered with this
        service and referenced by id in `request.element`. The handler does not
        wait for those values. Instead, a coroutine that awaits them, builds the
        `AnonymousTuple`, and calls the wrapped executor's `create_tuple` is
        scheduled on `self._event_loop`. The resulting future is stored under a
        freshly generated UUID.

        Args:
          request: An instance of `executor_pb2.CreateTupleRequest`.
          context: An instance of `grpc.ServicerContext`. Its status is set to
            `INVALID_ARGUMENT` if a `ValueError` or `TypeError` is raised while
            handling the request.

        Returns:
          An instance of `executor_pb2.CreateTupleResponse` referencing the new
          value, or an empty response if the request was rejected.
        """
        py_typecheck.check_type(request, executor_pb2.CreateTupleRequest)
        try:
            # Lock only for the lookups in the shared value table. An unknown
            # id surfaces as a lookup error (presumably KeyError) that is not
            # translated to INVALID_ARGUMENT here.
            with self._lock:
                elem_futures = [
                    self._values[e.value_ref.id] for e in request.element
                ]
            elem_names = [
                str(elem.name) if elem.name else None
                for elem in request.element
            ]

            async def _processing():
                # Runs on the event loop, so the element futures are awaited
                # there rather than blocking this handler thread.
                elem_values = await asyncio.gather(
                    *[asyncio.wrap_future(v) for v in elem_futures])
                elements = list(zip(elem_names, elem_values))
                anon_tuple = anonymous_tuple.AnonymousTuple(elements)
                return await self._executor.create_tuple(anon_tuple)

            # This handler runs on a gRPC server thread, not on the event
            # loop's thread. `loop.create_task` is not thread-safe (and does not
            # wake an idle loop), and it would store an `asyncio.Task`, whose
            # `.result()` raises until done. `run_coroutine_threadsafe`
            # schedules safely and yields a blocking
            # `concurrent.futures.Future`.
            # NOTE(review): confirm other handlers store the same future type.
            result_fut = asyncio.run_coroutine_threadsafe(
                _processing(), self._event_loop)
            result_id = str(uuid.uuid4())
            with self._lock:
                self._values[result_id] = result_fut
            return executor_pb2.CreateTupleResponse(
                value_ref=executor_pb2.ValueRef(id=result_id))
        except (ValueError, TypeError) as err:
            # Only synchronous failures land here. Errors raised inside
            # `_processing` are reported through `result_fut` when consumed.
            logging.error(traceback.format_exc())
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(str(err))
            return executor_pb2.CreateTupleResponse()
예제 #4
0
  def test_create_tuple_returns_remote_value(self, mock_stub):
    """create_tuple issues exactly one CreateTuple RPC and yields a RemoteValue."""
    # `mock_stub` is presumably the patched gRPC stub class (its decorator is
    # not shown); the executor talks to the instance it returns.
    stub_instance = mock_stub.return_value
    # A single canned response: a second RPC would raise StopIteration.
    stub_instance.CreateTuple = mock.Mock(
        side_effect=[executor_pb2.CreateTupleResponse()])
    event_loop = asyncio.get_event_loop()
    remote_ex = create_remote_executor()
    int_type = computation_types.TensorType(tf.int32)
    members = [
        remote_executor.RemoteValue(executor_pb2.ValueRef(), int_type,
                                    remote_ex) for _ in range(2)
    ]

    created = event_loop.run_until_complete(remote_ex.create_tuple(members))

    stub_instance.CreateTuple.assert_called_once()
    self.assertIsInstance(created, remote_executor.RemoteValue)