Esempio n. 1
0
    def CreateSelection(self, request, context):
        """Creates a selection embedded in the executor.

        Looks up the source value named by `request.source_ref`, waits for it,
        then schedules the selection (by `name` or by `index`, per the
        `selection` oneof) on `self._event_loop`. The response carries a newly
        generated reference under which the pending result is stored.

        Args:
          request: An instance of `executor_pb2.CreateSelectionRequest`.
          context: An instance of `grpc.ServicerContext`.

        Returns:
          An instance of `executor_pb2.CreateSelectionResponse`. If the request
          is rejected (`ValueError`/`TypeError`, including an unknown
          `source_ref`), the response is empty and `context` is set to
          `INVALID_ARGUMENT` with the error message as details.
        """
        py_typecheck.check_type(request, executor_pb2.CreateSelectionRequest)
        try:
            source_id = request.source_ref.id
            with self._lock:
                try:
                    source_val = self._values[source_id]
                except KeyError:
                    # The id comes from the client. An unknown reference is a
                    # client error, so report it as INVALID_ARGUMENT below
                    # rather than letting a raw KeyError escape the handler.
                    raise ValueError(
                        f'Unknown value reference: {source_id!r}') from None
            # NOTE: `result()` waits synchronously for the stored future. The
            # entries written by this method come from
            # `asyncio.run_coroutine_threadsafe`, so this handler thread blocks
            # until the source value has been computed.
            source = source_val.result()
            if request.WhichOneof('selection') == 'name':
                coro = self._executor.create_selection(source,
                                                       name=request.name)
            else:
                # Any non-'name' case (including an unset oneof) selects by
                # index, as before.
                coro = self._executor.create_selection(source,
                                                       index=request.index)
            result_val = asyncio.run_coroutine_threadsafe(
                coro, self._event_loop)
            result_id = str(uuid.uuid4())
            with self._lock:
                self._values[result_id] = result_val
            return executor_pb2.CreateSelectionResponse(
                value_ref=executor_pb2.ValueRef(id=result_id))
        except (ValueError, TypeError) as err:
            logging.error(traceback.format_exc())
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(str(err))
            return executor_pb2.CreateSelectionResponse()
Esempio n. 2
0
    def CreateSelection(
        self,
        request: executor_pb2.CreateSelectionRequest,
        context: grpc.ServicerContext,
    ) -> executor_pb2.CreateSelectionResponse:
        """Creates a selection embedded in the executor.

        The source value and the selection are both resolved inside a
        coroutine handed to `_run_coro_threadsafe_with_tracing`. This handler
        stores the returned future without waiting on it, under a newly
        generated id that is returned to the caller.

        Args:
          request: The selection request. `source_ref` names a previously
            stored value, and the `selection` oneof chooses `name` or `index`.
          context: The gRPC servicer context, passed to `_set_invalid_arg_err`
            when the request is rejected.

        Returns:
          A `CreateSelectionResponse` referencing the new value. If a
          `ValueError`/`TypeError` is raised while handling the request
          (including an unknown `source_ref`), an empty response is returned.
        """
        py_typecheck.check_type(request, executor_pb2.CreateSelectionRequest)
        try:
            source_id = request.source_ref.id
            with self._lock:
                try:
                    source_fut = self._values[source_id]
                except KeyError:
                    # The id is client-supplied. Route an unknown reference
                    # through the same invalid-argument handling as other
                    # malformed requests instead of letting KeyError escape.
                    raise ValueError(
                        f'Unknown value reference: {source_id!r}') from None

            async def _processing():
                # Await the stored future rather than blocking on it.
                source = await asyncio.wrap_future(source_fut)
                if request.WhichOneof('selection') == 'name':
                    return await self.executor.create_selection(
                        source, name=request.name)
                # Any non-'name' case (including an unset oneof) selects by
                # index.
                return await self.executor.create_selection(
                    source, index=request.index)

            result_fut = self._run_coro_threadsafe_with_tracing(_processing())
            result_id = str(uuid.uuid4())
            with self._lock:
                self._values[result_id] = result_fut
            return executor_pb2.CreateSelectionResponse(
                value_ref=executor_pb2.ValueRef(id=result_id))
        except (ValueError, TypeError) as err:
            _set_invalid_arg_err(context, err)
            return executor_pb2.CreateSelectionResponse()
    def test_create_selection_returns_value(self, mock_executor_grpc_stub):
        """The stub forwards CreateSelection and hands back the gRPC response."""
        expected_response = executor_pb2.CreateSelectionResponse()
        grpc_stub_instance = mock_executor_grpc_stub.return_value
        grpc_stub_instance.CreateSelection = mock.Mock(
            side_effect=[expected_response])
        stub_under_test = create_stub()
        selection_request = executor_pb2.CreateSelectionRequest()

        actual_response = stub_under_test.create_selection(
            request=selection_request)

        grpc_stub_instance.CreateSelection.assert_called_once()
        self.assertEqual(actual_response, expected_response)
    def test_create_selection_returns_remote_value(self, mock_stub):
        """Selecting from a RemoteValue through the stub yields a RemoteValue."""
        stubbed_response = executor_pb2.CreateSelectionResponse()
        mock_stub.create_selection.return_value = stubbed_response
        remote_exec = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(remote_exec, mock_stub)
        pair_type = computation_types.StructType([tf.int32, tf.int32])
        source_value = remote_executor.RemoteValue(
            executor_pb2.ValueRef(), pair_type, remote_exec)

        selection_coro = remote_exec.create_selection(source_value, 0)
        selection = asyncio.run(selection_coro)

        mock_stub.create_selection.assert_called_once()
        self.assertIsInstance(selection, remote_executor.RemoteValue)
    def test_create_selection_returns_remote_value(self, mock_stub):
        """A streaming RemoteExecutor turns a selection into a RemoteValue."""
        streamed_response = executor_pb2.ExecuteResponse(
            create_selection=executor_pb2.CreateSelectionResponse())
        remote_exec = _setup_mock_streaming_executor(mock_stub,
                                                     streamed_response)
        # NOTE(review): asyncio.get_event_loop() is deprecated when no current
        # event loop is set (newer Python versions). Consider asyncio.run()
        # once it is confirmed the executor does not rely on the default loop.
        event_loop = asyncio.get_event_loop()
        pair_type = computation_types.StructType([tf.int32, tf.int32])
        source_value = remote_executor.RemoteValue(
            executor_pb2.ValueRef(), pair_type, remote_exec)

        pending = remote_exec.create_selection(source_value, index=0)
        selection = event_loop.run_until_complete(pending)

        self.assertIsInstance(selection, remote_executor.RemoteValue)
Esempio n. 6
0
  def test_create_selection_returns_remote_value(self, mock_stub):
    """create_selection issues one CreateSelection RPC, returns a RemoteValue."""
    canned_response = executor_pb2.CreateSelectionResponse()
    grpc_stub_instance = mock_stub.return_value
    grpc_stub_instance.CreateSelection = mock.Mock(
        side_effect=[canned_response])
    # NOTE(review): asyncio.get_event_loop() is deprecated when no current
    # event loop is set (newer Python versions); kept as-is here.
    event_loop = asyncio.get_event_loop()
    remote_exec = create_remote_executor()
    pair_type = computation_types.NamedTupleType([tf.int32, tf.int32])
    source_value = remote_executor.RemoteValue(
        executor_pb2.ValueRef(), pair_type, remote_exec)

    pending = remote_exec.create_selection(source_value, index=0)
    selection = event_loop.run_until_complete(pending)

    grpc_stub_instance.CreateSelection.assert_called_once()
    self.assertIsInstance(selection, remote_executor.RemoteValue)
Esempio n. 7
0
    def CreateSelection(
        self,
        request: executor_pb2.CreateSelectionRequest,
        context: grpc.ServicerContext,
    ) -> executor_pb2.CreateSelectionResponse:
        """Creates a selection embedded in the executor.

        Looks up the stored future for `request.source_ref` and hands a
        coroutine to `_run_coro_threadsafe_with_tracing`. That coroutine waits
        for the source value and then selects `request.index` from it, using
        the executor returned by `self.executor(request, context)`. The
        resulting future is stored under a fresh id, which is returned.

        Errors raised while handling the request are dealt with by
        `_try_handle_request_context`, which is not shown here; presumably it
        returns an empty `CreateSelectionResponse` on failure (verify).
        """
        py_typecheck.check_type(request, executor_pb2.CreateSelectionRequest)
        response_cls = executor_pb2.CreateSelectionResponse
        with self._try_handle_request_context(request, context, response_cls):
            with self._lock:
                pending_source = self._values[request.source_ref.id]

            async def _process_create_selection():
                # Wait for the source without blocking, then select from it.
                resolved = await asyncio.wrap_future(pending_source)
                target = self.executor(request, context)
                return await target.create_selection(resolved, request.index)

            pending_selection = self._run_coro_threadsafe_with_tracing(
                _process_create_selection())
            selection_id = str(uuid.uuid4())
            with self._lock:
                self._values[selection_id] = pending_selection
            selection_ref = executor_pb2.ValueRef(id=selection_id)
            return response_cls(value_ref=selection_ref)