示例#1
0
    def __init__(self, logger=None, task_id=None, server_address=None, agent_token=None, ignore_errors=False,
                 ignore_task_id=False):
        """Create the application service and bind it to the current asyncio event loop.

        :param logger: logger to use; ``default_logger`` when ``None``.
        :param task_id: task id; read from the ``TASK_ID`` environment variable when ``None``.
        :param server_address: server address; read from ``os.environ[SERVER_ADDRESS]`` when ``None``.
        :param agent_token: agent token; read from ``os.environ[AGENT_TOKEN]`` when ``None``.
        :param ignore_errors: stored as ``self._ignore_errors``.
        :param ignore_task_id: forwarded to ``Api.from_env``.
        :raises KeyError: if a value is omitted and its environment variable is not set.
        """
        self._ignore_task_id = ignore_task_id
        self.logger = take_with_default(logger, default_logger)
        self._ignore_errors = ignore_errors
        # Read each environment variable only when the caller did not supply the value, so an
        # explicit argument works even when the variable is absent (an eager os.environ[...]
        # lookup would raise KeyError regardless of the argument).
        self.task_id = task_id if task_id is not None else os.environ["TASK_ID"]
        self.server_address = server_address if server_address is not None else os.environ[SERVER_ADDRESS]
        self.agent_token = agent_token if agent_token is not None else os.environ[AGENT_TOKEN]
        self.public_api = Api.from_env(ignore_task_id=self._ignore_task_id)
        self._app_url = self.public_api.app.get_url(self.task_id)
        self._session_dir = "/sessions/{}".format(self.task_id)

        self.api = AgentAPI(token=self.agent_token, server_address=self.server_address, ext_logger=self.logger)
        self.api.add_to_metadata('x-task-id', str(self.task_id))

        self.callbacks = {}
        # NOTE(review): unbounded; a maxsize=self.QUEUE_MAX_SIZE bound was commented out in the
        # original code — confirm this is intended.
        self.processing_queue = queue.Queue()
        self.logger.debug('App is created', extra={"task_id": self.task_id, "server_address": self.server_address})

        self._ignore_stop_for_debug = False
        self._error = None
        self.stop_event = asyncio.Event()

        self.executor = concurrent.futures.ThreadPoolExecutor()
        self.loop = asyncio.get_event_loop()
        # Schedule a shutdown on common termination signals.
        # NOTE(review): loop.add_signal_handler is only available on Unix event loops.
        signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT, signal.SIGQUIT)
        for s in signals:
            # `s=s` binds the current signal; a plain closure would see only the last one.
            self.loop.add_signal_handler(s, lambda s=s: asyncio.create_task(self._shutdown(signal=s)))
        # Route unhandled event-loop exceptions to handle_exception.
        self.loop.set_exception_handler(self.handle_exception)
示例#2
0
class AppService:
    """Route commands from the agent's general event stream to registered handlers.

    ``run`` reads events on the calling thread and hands them through a bounded queue to a
    single daemon worker thread, which calls the handler registered for each ``command``.
    """
    NETW_CHUNK_SIZE = 1048576
    QUEUE_MAX_SIZE = 2000  # Maximum number of in-flight requests to avoid exhausting server memory.

    def __init__(self, logger, task_config):
        """Connect to the agent and prepare request routing.

        :param logger: logger used for all diagnostics.
        :param task_config: mapping with at least ``server_address``, ``agent_token`` and ``task_id``.
        """
        self.logger = logger
        self.server_address = task_config['server_address']
        self.api = AgentAPI(token=task_config['agent_token'],
                            server_address=self.server_address,
                            ext_logger=self.logger)

        self.api.add_to_metadata('x-task-id', str(task_config['task_id']))

        self.routes = {}

        self.thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=10)
        self.processing_queue = Queue(maxsize=self.QUEUE_MAX_SIZE)
        # Log the configuration without the agent token so the secret does not end up in the logs.
        safe_config = {key: value for key, value in task_config.items() if key != 'agent_token'}
        self.logger.debug('Created AgentRPCServicer', extra=safe_config)

    def add_route(self, route, func):
        """Register ``func`` as the handler for messages whose ``command`` equals ``route``."""
        self.routes[route] = func

    def _processing(self):
        """Serve queued messages forever, one at a time, logging any handler failure."""
        while True:
            request_msg = self.processing_queue.get(block=True, timeout=None)
            try:
                # A command without a registered route raises KeyError here and is logged
                # like any other handler error.
                self.routes[request_msg["command"]](request_msg)
            except Exception as e:
                self.logger.error(traceback.format_exc(),
                                  exc_info=True,
                                  extra={'exc_str': str(e)})

    def run(self):
        """Start the worker thread and feed it events from the server stream.

        Blocks the calling thread. ``put`` blocks while the queue is full, which applies
        back-pressure to the stream reader.

        :raises ConnectionClosedByServerException: when the server ends the event stream.
        """
        def seq_inf_wrapped():
            function_wrapper(self._processing)  # exit if raised

        process_thread = threading.Thread(target=seq_inf_wrapped, daemon=True)
        process_thread.start()

        for gen_event in self.api.get_endless_stream('GetGeneralEventsStream',
                                                     api_proto.GeneralEvent,
                                                     api_proto.Empty()):
            try:
                data = {}
                if gen_event.data is not None and gen_event.data != b'':
                    data = json.loads(gen_event.data.decode('utf-8'))

                event_obj = {REQUEST_ID: gen_event.request_id, **data}
                self.processing_queue.put(event_obj, block=True)
            except Exception as error:
                self.logger.warning('App exception: ',
                                    extra={"error_message": str(error)})

        raise ConnectionClosedByServerException(
            'Requests stream to a deployed model closed by the server.')
示例#3
0
    def __init__(self, logger, task_config):
        """Connect to the agent and prepare request routing.

        :param logger: logger used for all diagnostics.
        :param task_config: mapping with at least ``server_address``, ``agent_token`` and ``task_id``.
        """
        self.logger = logger
        self.server_address = task_config['server_address']
        self.api = AgentAPI(token=task_config['agent_token'],
                            server_address=self.server_address,
                            ext_logger=self.logger)

        self.api.add_to_metadata('x-task-id', str(task_config['task_id']))

        self.routes = {}

        self.thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=10)
        self.processing_queue = Queue(maxsize=self.QUEUE_MAX_SIZE)
        # Log the configuration without the agent token so the secret does not end up in the logs.
        safe_config = {key: value for key, value in task_config.items() if key != 'agent_token'}
        self.logger.debug('Created AgentRPCServicer', extra=safe_config)
示例#4
0
    def __init__(self, logger, model_applier: SingleImageInferenceInterface, conn_config, cache):
        """Connect to the agent and set up the request queues.

        :param logger: logger used for all diagnostics.
        :param model_applier: model wrapper, stored as ``self.model_applier``.
        :param conn_config: mapping with ``token``, ``server_address`` and ``task_id`` keys.
        :param cache: image cache, stored as ``self.image_cache``.
        """
        self.logger = logger
        self.api = AgentAPI(token=conn_config['token'],
                            server_address=conn_config['server_address'],
                            ext_logger=self.logger)
        self.api.add_to_metadata('x-task-id', conn_config['task_id'])

        self.model_applier = model_applier
        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=10)
        self.download_queue = Queue(maxsize=self.QUEUE_MAX_SIZE)
        self.final_processing_queue = Queue(maxsize=self.QUEUE_MAX_SIZE)
        self.image_cache = cache
        self._default_inference_mode_config = InfModeFullImage.make_default_config(
            model_result_suffix=MODEL_RESULT_SUFFIX)
        # Log the connection config without the agent token so the secret does not end up in the logs.
        safe_config = {key: value for key, value in conn_config.items() if key != 'token'}
        self.logger.info('Created AgentRPCServicer', extra=safe_config)
class AgentRPCServicerBase:
    """Serve single-image requests received from the agent's general event stream.

    Pipeline:

    * ``run_inf_loop`` reads events into ``download_queue``.
    * A loader thread fetches the input image for inference requests and moves each request
      to ``final_processing_queue``.
    * An inference thread processes requests one at a time.
    * Responses are sent from ``thread_pool``.

    Subclasses implement ``_do_single_img_inference`` and ``_get_out_meta``.
    """
    NETW_CHUNK_SIZE = 1048576
    QUEUE_MAX_SIZE = 2000  # Maximum number of in-flight requests to avoid exhausting server memory.

    def __init__(self, logger, model_applier: SingleImageInferenceInterface,
                 conn_config, cache):
        """Connect to the agent and set up the request queues.

        :param logger: logger used for all diagnostics (the class also calls ``logger.trace``).
        :param model_applier: model wrapper, stored as ``self.model_applier``.
        :param conn_config: mapping with ``token``, ``server_address`` and ``task_id`` keys.
        :param cache: image cache; this class calls ``cache.get(hash)`` and ``cache.add(hash, image)``.
        """
        self.logger = logger
        self.api = AgentAPI(token=conn_config['token'],
                            server_address=conn_config['server_address'],
                            ext_logger=self.logger)
        self.api.add_to_metadata('x-task-id', conn_config['task_id'])

        self.model_applier = model_applier
        self.thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=10)
        self.download_queue = Queue(maxsize=self.QUEUE_MAX_SIZE)
        self.final_processing_queue = Queue(maxsize=self.QUEUE_MAX_SIZE)
        self.image_cache = cache
        self._default_inference_mode_config = InfModeFullImage.make_default_config(
            model_result_suffix=MODEL_RESULT_SUFFIX)
        # Log the connection config without the agent token so the secret does not end up in the logs.
        safe_config = {key: value for key, value in conn_config.items() if key != 'token'}
        self.logger.info('Created AgentRPCServicer', extra=safe_config)

    def _load_image_from_sly(self, req_id, image_hash, src_node_token):
        """Return the decoded image for ``image_hash``, downloading and caching it on a cache miss."""
        self.logger.trace('Will look for image.',
                          extra={
                              'request_id': req_id,
                              'image_hash': image_hash,
                              'src_node_token': src_node_token
                          })
        img_data = self.image_cache.get(image_hash)
        if img_data is None:
            img_data_packed = download_image_from_remote(
                self.api, image_hash, src_node_token, self.logger)
            img_data = decode_image(img_data_packed)
            self.image_cache.add(image_hash, img_data)

        return img_data

    def _load_arbitrary_image(self, req_id):
        """Download and decode the image payload attached to request ``req_id`` (not cached)."""
        self.logger.trace('Will load arbitrary image.',
                          extra={'request_id': req_id})
        img_data_packed = download_data_from_remote(self.api, req_id,
                                                    self.logger)
        img_data = decode_image(img_data_packed)
        return img_data

    def _load_data_if_required(self, event_obj):
        """Attach the input image to inference requests, then queue the request for processing.

        Requests of other types are queued unchanged. On any failure an error response is
        sent back for the request instead of queuing it.

        :param event_obj: dict with ``request_id`` and the request payload under ``DATA``.
        """
        try:
            req_id = event_obj['request_id']
            event_data = event_obj[DATA]
            request_type = event_data.get(REQUEST_TYPE, INFERENCE)
            if request_type == INFERENCE:
                # For inference we need to download an image and add it to the event data.
                image_hash = event_data.get('image_hash')
                if image_hash is None:
                    img_data = self._load_arbitrary_image(req_id)
                else:
                    # Read from the same payload the other fields came from (event_obj[DATA]).
                    src_node_token = event_data.get('src_node_token', '')
                    img_data = self._load_image_from_sly(
                        req_id, image_hash, src_node_token)
                event_data['image_arr'] = img_data
                self.logger.trace('Input image obtained.',
                                  extra={'request_id': req_id})
            self.final_processing_queue.put(item=(event_data, req_id))
        except Exception as e:
            self.logger.error(traceback.format_exc(),
                              exc_info=True,
                              extra={'exc_str': str(e)})
            res_msg = {
                'success': False,
                'error': json.dumps(traceback.format_exc())
            }
            self.thread_pool.submit(function_wrapper_nofail, self._send_data,
                                    res_msg, req_id)  # skip errors

    def _send_data(self, out_msg, req_id):
        """Serialize ``out_msg`` to UTF-8 JSON and stream it back for request ``req_id``.

        The bytes go through ``send_from_memory_generator`` with ``NETW_CHUNK_SIZE``.
        """
        self.logger.trace('Will send output data.',
                          extra={'request_id': req_id})
        out_bytes = json.dumps(out_msg).encode('utf-8')

        self.api.put_stream_with_data('SendGeneralEventData',
                                      api_proto.Empty,
                                      send_from_memory_generator(
                                          out_bytes, self.NETW_CHUNK_SIZE),
                                      addit_headers={'x-request-id': req_id})
        self.logger.trace('Output data is sent.', extra={'request_id': req_id})

    def _final_processing(self, in_msg):
        """Run one request and return the fields to merge into its response.

        :param in_msg: request payload; inference requests carry the decoded image in ``image_arr``.
        :returns: dict of response fields.
        :raises RuntimeError: if the image is not 3-D with 3 or 4 channels in its last axis,
            or if the request type is unknown.
        """
        request_type = in_msg.get(REQUEST_TYPE, INFERENCE)

        if request_type == INFERENCE:
            img = in_msg['image_arr']
            if len(img.shape) != 3 or img.shape[2] not in [3, 4]:
                raise RuntimeError(
                    'Expect 3- or 4-channel image RGB(RGBA) [0..255].')
            elif img.shape[2] == 4:
                img = drop_image_alpha_channel(img)
            return self._do_single_img_inference(img, in_msg)
        elif request_type == GET_OUT_META:
            return {'out_meta': self._get_out_meta(in_msg).to_json()}
        else:
            raise RuntimeError(
                'Unknown request type: {}. Only the following request types are supported: {}'
                .format(request_type, SUPPORTED_REQUEST_TYPES))

    def _do_single_img_inference(self, img, in_msg):
        """Run the model on one image; must be implemented by subclasses."""
        raise NotImplementedError()

    def _get_out_meta(self, in_msg):
        """Return the model's output meta (an object with ``to_json``); must be implemented by subclasses."""
        raise NotImplementedError()

    def _sequential_final_processing(self):
        """Process queued requests forever, one at a time, sending a success or error response for each."""
        while True:
            in_msg, req_id = self.final_processing_queue.get(block=True,
                                                             timeout=None)
            res_msg = {}
            try:
                res_msg.update(self._final_processing(in_msg))
                res_msg.update({'success': True})
            except Exception as e:
                self.logger.error(traceback.format_exc(),
                                  exc_info=True,
                                  extra={'exc_str': str(e)})
                res_msg.update({
                    'success': False,
                    'error': json.dumps(traceback.format_exc())
                })

            self.thread_pool.submit(function_wrapper_nofail, self._send_data,
                                    res_msg, req_id)  # skip errors

    def _load_data_loop(self):
        """Move requests from ``download_queue`` through ``_load_data_if_required`` forever."""
        while True:
            event_obj = self.download_queue.get(block=True, timeout=None)
            self._load_data_if_required(event_obj)

    def run_inf_loop(self):
        """Start the loader and inference threads, then feed them events from the server stream.

        Blocks the calling thread until the server stream ends.
        """
        def seq_inf_wrapped():
            function_wrapper(
                self._sequential_final_processing)  # exit if raised

        load_data_thread = threading.Thread(target=self._load_data_loop,
                                            daemon=True)
        load_data_thread.start()
        inference_thread = threading.Thread(target=seq_inf_wrapped,
                                            daemon=True)
        inference_thread.start()
        report_agent_rpc_ready()

        for gen_event in self.api.get_endless_stream('GetGeneralEventsStream',
                                                     api_proto.GeneralEvent,
                                                     api_proto.Empty()):
            try:
                request_id = gen_event.request_id

                data = {}
                if gen_event.data is not None and gen_event.data != b'':
                    data = json.loads(gen_event.data.decode('utf-8'))

                event_obj = {'request_id': request_id, 'data': data}
                self.logger.debug('GET_INFERENCE_CALL', extra=event_obj)
                self.download_queue.put(event_obj, block=True)
            except Exception as error:
                self.logger.warning('Inference exception: ',
                                    extra={"error_message": str(error)})
                res_msg = {'success': False, 'error': json.dumps(str(error))}
                self.thread_pool.submit(function_wrapper_nofail,
                                        self._send_data, res_msg, request_id)
示例#6
0
class AppService:
    """Run an application session: receive events from the agent and dispatch them to callbacks.

    Components:

    * ``publish`` runs ``publish_sync`` on ``self.executor``; it reads ``GetGeneralEventsStream``
      into ``processing_queue``.
    * ``consume`` runs ``consume_sync`` there as well; it hands each queued message to
      ``handle_message_sync`` on the executor.

    ``run`` drives the asyncio loop until ``_shutdown`` stops it.
    """
    NETW_CHUNK_SIZE = 1048576
    QUEUE_MAX_SIZE = 2000  # Maximum number of in-flight requests to avoid exhausting server memory.
    DEFAULT_EVENTS = [STOP_COMMAND, *IMAGE_ANNOTATION_EVENTS]

    def __init__(self, logger=None, task_id=None, server_address=None, agent_token=None, ignore_errors=False,
                 ignore_task_id=False):
        """Create the application service and bind it to the current asyncio event loop.

        :param logger: logger to use; ``default_logger`` when ``None``.
        :param task_id: task id; read from the ``TASK_ID`` environment variable when ``None``.
        :param server_address: server address; read from ``os.environ[SERVER_ADDRESS]`` when ``None``.
        :param agent_token: agent token; read from ``os.environ[AGENT_TOKEN]`` when ``None``.
        :param ignore_errors: when ``False``, a failing callback stops the app.
        :param ignore_task_id: forwarded to ``Api.from_env`` and to per-request ``Api`` objects.
        :raises KeyError: if a value is omitted and its environment variable is not set.
        """
        self._ignore_task_id = ignore_task_id
        self.logger = take_with_default(logger, default_logger)
        self._ignore_errors = ignore_errors
        # Read each environment variable only when the caller did not supply the value, so an
        # explicit argument works even when the variable is absent (an eager os.environ[...]
        # lookup would raise KeyError regardless of the argument).
        self.task_id = task_id if task_id is not None else os.environ["TASK_ID"]
        self.server_address = server_address if server_address is not None else os.environ[SERVER_ADDRESS]
        self.agent_token = agent_token if agent_token is not None else os.environ[AGENT_TOKEN]
        self.public_api = Api.from_env(ignore_task_id=self._ignore_task_id)
        self._app_url = self.public_api.app.get_url(self.task_id)
        self._session_dir = "/sessions/{}".format(self.task_id)

        self.api = AgentAPI(token=self.agent_token, server_address=self.server_address, ext_logger=self.logger)
        self.api.add_to_metadata('x-task-id', str(self.task_id))

        self.callbacks = {}
        # NOTE(review): unbounded; a maxsize=self.QUEUE_MAX_SIZE bound was commented out in the
        # original code — confirm this is intended.
        self.processing_queue = queue.Queue()
        self.logger.debug('App is created', extra={"task_id": self.task_id, "server_address": self.server_address})

        self._ignore_stop_for_debug = False
        self._error = None
        self.stop_event = asyncio.Event()

        self.executor = concurrent.futures.ThreadPoolExecutor()
        self.loop = asyncio.get_event_loop()
        # Schedule a shutdown on common termination signals.
        # NOTE(review): loop.add_signal_handler is only available on Unix event loops.
        signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT, signal.SIGQUIT)
        for s in signals:
            # `s=s` binds the current signal; a plain closure would see only the last one.
            self.loop.add_signal_handler(s, lambda s=s: asyncio.create_task(self._shutdown(signal=s)))
        # Route unhandled event-loop exceptions to handle_exception.
        self.loop.set_exception_handler(self.handle_exception)

    def handle_exception(self, loop, context):
        """Event-loop exception handler: log the problem and schedule a shutdown.

        :param loop: the loop reporting the error (unused).
        :param context: asyncio exception-handler context dict.
        """
        # context["message"] will always be there; but context["exception"] may not
        msg = context.get("exception", context["message"])
        if isinstance(msg, Exception):
            # Pass the exception itself: asyncio does not always call this handler from inside
            # an ``except`` block, so ``exc_info=True`` can lose the traceback. "future" is
            # absent for errors raised by plain callbacks, hence ``.get``.
            self.logger.error(msg, exc_info=msg, extra={'future_info': context.get("future")})
        else:
            self.logger.error("Caught exception: {}".format(msg))

        self.logger.info("Shutting down...")
        asyncio.create_task(self._shutdown())

    @property
    def session_dir(self):
        """Session directory, ``/sessions/<task_id>``."""
        return self._session_dir

    @property
    def repo_dir(self):
        """``repo`` subdirectory of the session directory."""
        return os.path.join(self._session_dir, "repo")

    @property
    def data_dir(self):
        """``data`` subdirectory of the session directory."""
        return os.path.join(self._session_dir, "data")

    @property
    def app_url(self):
        """App URL obtained from the public API at construction time."""
        return self._app_url

    def _add_callback(self, callback_name, func):
        """Register ``func`` as the handler for command ``callback_name``."""
        self.callbacks[callback_name] = func

    def callback(self, callback_name):
        """A decorator that is used to register a view function for a
        given application command.  This does the same thing as :meth:`_add_callback`
        but is intended for decorator usage::
            @app.callback('calc')
            def calc_func():
                return 'Hello World'
        The undecorated function is what gets registered; the returned wrapper forwards
        arguments and the return value unchanged.
        :param callback_name: the command name as string
        """
        def decorator(f):
            self._add_callback(callback_name, f)

            @functools.wraps(f)
            def wrapper(*args, **kwargs):
                return f(*args, **kwargs)
            return wrapper
        return decorator

    def _on_message_error(self, e):
        """Log a failure while handling a message and stop the app unless errors are ignored.

        Must be called from inside an ``except`` block so the active traceback is logged.
        """
        self.logger.error(traceback.format_exc(), exc_info=True, extra={'exc_str': repr(e)})
        if self._ignore_errors is False:
            self.logger.info("App will be stopped due to error")
            asyncio.run_coroutine_threadsafe(self._shutdown(error=e), self.loop)

    def handle_message_sync(self, request_msg):
        """Dispatch one request message to its registered callback (runs on an executor thread).

        Behavior by case:

        * A ``stop`` command without a registered callback is served by ``_default_stop``.
        * Messages missing ``command``/``api_token``, and commands without a callback, are
          logged and skipped.
        * Any other failure, including a ``KeyError`` raised inside a user callback, is logged
          with its traceback and, unless ``ignore_errors`` was set, stops the app.

        :param request_msg: message dict with ``command`` and ``api_token`` and optional state/context.
        """
        try:
            state = request_msg.get(STATE, None)
            context = request_msg.get(CONTEXT, None)
            command = request_msg["command"]
            user_api_token = request_msg["api_token"]
            user_public_api = Api(self.server_address, user_api_token, retry_count=5, external_logger=self.logger,
                                  ignore_task_id=self._ignore_task_id)

            if command == STOP_COMMAND:
                self.logger.info("APP receives stop signal from user")
                # NOTE(review): asyncio.Event is not thread-safe and this runs on an executor
                # thread; consider self.loop.call_soon_threadsafe(self.stop_event.set).
                self.stop_event.set()

            if command == STOP_COMMAND and command not in self.callbacks:
                _default_stop(user_public_api, self.task_id, context, state, self.logger)
                if self._ignore_stop_for_debug is False:
                    asyncio.run_coroutine_threadsafe(self._shutdown(), self.loop)
                else:
                    self.logger.info("STOP event is ignored ...")
                return
            if command in AppService.DEFAULT_EVENTS and command not in self.callbacks:
                raise KeyError("App received default command {!r}. Use decorator \"callback\" to handle it."
                               .format(command))
            if command not in self.callbacks:
                raise KeyError("App received unhandled command {!r}. Use decorator \"callback\" to handle it."
                               .format(command))
        except KeyError as e:
            # Malformed message or no callback for the command: report it and skip the message.
            self.logger.error(e, exc_info=False)
            return
        except Exception as e:
            self._on_message_error(e)
            return

        if command == STOP_COMMAND and self._ignore_stop_for_debug is not False:
            self.logger.info("STOP event is ignored ...")
            return

        # User callbacks run outside the try block above so a KeyError they raise is treated
        # like any other callback failure instead of being mistaken for an unhandled command.
        try:
            self.callbacks[command](api=user_public_api,
                                    task_id=self.task_id,
                                    context=context,
                                    state=state,
                                    app_logger=self.logger)
            if command == STOP_COMMAND:
                asyncio.run_coroutine_threadsafe(self._shutdown(), self.loop)
        except Exception as e:
            self._on_message_error(e)

    def consume_sync(self):
        """Forever take messages off ``processing_queue`` and handle each on the executor."""
        while True:
            request_msg = self.processing_queue.get()
            to_log = _remove_sensitive_information(request_msg)
            self.logger.debug('FULL_TASK_MESSAGE', extra={'task_msg': to_log})
            asyncio.ensure_future(
                self.loop.run_in_executor(self.executor, self.handle_message_sync, request_msg), loop=self.loop
            )

    async def consume(self):
        """Start ``consume_sync`` on the executor and return without waiting for it."""
        self.logger.info("Starting consumer")
        asyncio.ensure_future(
            self.loop.run_in_executor(self.executor, self.consume_sync), loop=self.loop
        )

    def publish_sync(self, initial_events=None):
        """Queue ``initial_events``, then queue every event from the server stream.

        Each initial event dict is given the ``api_token`` from the environment in place.

        :raises ConnectionClosedByServerException: when the server ends the event stream.
        """
        if initial_events is not None:
            for event_obj in initial_events:
                event_obj["api_token"] = os.environ[API_TOKEN]
                self.processing_queue.put(event_obj)

        for gen_event in self.api.get_endless_stream('GetGeneralEventsStream', api_proto.GeneralEvent, api_proto.Empty()):
            try:
                data = {}
                if gen_event.data is not None and gen_event.data != b'':
                    data = json.loads(gen_event.data.decode('utf-8'))

                event_obj = {REQUEST_ID: gen_event.request_id, **data}
                self.processing_queue.put(event_obj)
            except Exception as error:
                self.logger.warning('App exception: ', extra={"error_message": repr(error)})

        raise ConnectionClosedByServerException('Requests stream to a deployed model closed by the server.')

    async def publish(self, initial_events=None):
        """Start ``publish_sync`` on the executor and return without waiting for it."""
        self.logger.info("Starting publisher")
        asyncio.ensure_future(
            self.loop.run_in_executor(self.executor, self.publish_sync, initial_events), loop=self.loop
        )

    def run(self, template_path=None, data=None, state=None, initial_events=None):
        """Initialize the app session and block until the event loop is stopped.

        :param template_path: GUI template file. When ``None`` it is taken from ``gui_template``
            in the repo's ``config.json``, falling back to ``gui.html`` next to the entry script.
        :param data: initial app data passed to ``app.initialize``.
        :param state: initial app state passed to ``app.initialize``.
        :param initial_events: events queued before the server stream is read.
        :raises Exception: the error recorded by ``_shutdown``, if any.
        """
        if template_path is None:
            # read config
            config_path = os.path.join(self.repo_dir, os.environ.get("CONFIG_DIR", ""), 'config.json')
            if file_exists(config_path):
                # we are not in debug mode
                config = load_json_file(config_path)
                template_path = config.get('gui_template', None)
                if template_path is None:
                    self.logger.info("there is no gui_template field in config.json")
                else:
                    template_path = os.path.join(self.repo_dir, template_path)

            if template_path is None:
                template_path = os.path.join(os.path.dirname(sys.argv[0]), 'gui.html')

        if not file_exists(template_path):
            self.logger.info("App will be running without GUI", extra={"app_url": self.app_url})
            template = ""
        else:
            # Explicit encoding so the template does not depend on the platform locale.
            with open(template_path, 'r', encoding='utf-8') as file:
                template = file.read()

        self.public_api.app.initialize(self.task_id, template, data, state)
        self.logger.info("Application session is initialized", extra={"app_url": self.app_url})

        try:
            self.loop.create_task(self.publish(initial_events), name="Publisher")
            self.loop.create_task(self.consume(), name="Consumer")
            self.loop.run_forever()
        finally:
            self.loop.close()
            self.logger.info("Successfully shutdown the APP service.")

        if self._error is not None:
            raise self._error

    def stop(self, wait=True):
        """Stop the app.

        :param wait: when ``True``, queue a ``stop`` command that is handled like a user stop;
            otherwise schedule ``_shutdown`` on the loop immediately.
        """
        # TODO: add timeout
        if wait is True:
            event_obj = {"command": "stop", "api_token": os.environ[API_TOKEN]}
            self.processing_queue.put(event_obj)
        else:
            self.logger.info('Stop app (force, no wait)', extra={'event_type': EventType.APP_FINISHED})
            asyncio.run_coroutine_threadsafe(self._shutdown(), self.loop)

    async def _shutdown(self, signal=None, error=None):
        """Cleanup tasks tied to the service's shutdown.

        Steps:

        * Cancel and await every other task on the loop.
        * Shut down the executor without waiting.
        * Release the workers' internal state locks.
        * Stop the loop.

        ``run`` re-raises ``error`` after the loop stops.

        :param signal: the signal that triggered the shutdown, if any (only logged).
        :param error: exception to re-raise from ``run``, if any.
        """
        # Record the error first: a concurrently running _shutdown cancels this task while it
        # awaits below, which would otherwise drop the error.
        if error is not None:
            self._error = error
        if signal:
            self.logger.info(f"Received exit signal {signal.name}...")
        self.logger.info("Nacking outstanding messages")
        current = asyncio.current_task()
        tasks = [t for t in asyncio.all_tasks() if t is not current]

        for task in tasks:
            task.cancel()

        self.logger.info(f"Cancelling {len(tasks)} outstanding tasks")
        await asyncio.gather(*tasks, return_exceptions=True)

        self.logger.info("Shutting down ThreadPoolExecutor")
        self.executor.shutdown(wait=False)

        # NOTE(review): presumably releases each worker's internal state lock so interpreter exit
        # does not block on threads stuck in blocking calls — confirm. Relies on CPython-private
        # attributes (_threads, _tstate_lock); failures are deliberately ignored.
        self.logger.info(f"Releasing {len(self.executor._threads)} threads from executor")
        for thread in self.executor._threads:
            try:
                thread._tstate_lock.release()
            except Exception:
                pass

        self.logger.info("Flushing metrics")
        self.loop.stop()