def __init__(self, logger=None, task_id=None, server_address=None, agent_token=None,
             ignore_errors=False, ignore_task_id=False):
    """Prepare an app session: settings, API clients, work queue and event loop.

    ``logger``, ``task_id``, ``server_address`` and ``agent_token`` are each
    combined with a fallback through ``take_with_default``.  The environment
    fallbacks are looked up eagerly as call arguments, so a missing variable
    raises ``KeyError`` even when an explicit value is supplied.

    :param logger: logger for the app; fallback is ``default_logger``
    :param task_id: task id; fallback is ``os.environ["TASK_ID"]``
    :param server_address: fallback is ``os.environ[SERVER_ADDRESS]``
    :param agent_token: fallback is ``os.environ[AGENT_TOKEN]``
    :param ignore_errors: stored as ``_ignore_errors``
    :param ignore_task_id: stored and forwarded to ``Api.from_env``
    """
    # Plain settings.
    self._ignore_task_id = ignore_task_id
    self.logger = take_with_default(logger, default_logger)
    self._ignore_errors = ignore_errors

    # Connection parameters (explicit value or environment fallback).
    self.task_id = take_with_default(task_id, os.environ["TASK_ID"])
    self.server_address = take_with_default(server_address, os.environ[SERVER_ADDRESS])
    self.agent_token = take_with_default(agent_token, os.environ[AGENT_TOKEN])

    # Public API client and values derived from the task id.
    self.public_api = Api.from_env(ignore_task_id=self._ignore_task_id)
    self._app_url = self.public_api.app.get_url(self.task_id)
    self._session_dir = "/sessions/{}".format(self.task_id)

    # Agent API client; every request carries the task id as metadata.
    self.api = AgentAPI(
        token=self.agent_token,
        server_address=self.server_address,
        ext_logger=self.logger,
    )
    self.api.add_to_metadata('x-task-id', str(self.task_id))

    self.callbacks = {}
    self.processing_queue = queue.Queue()  # (maxsize=self.QUEUE_MAX_SIZE)

    self.logger.debug(
        'App is created',
        extra={"task_id": self.task_id, "server_address": self.server_address},
    )

    self._ignore_stop_for_debug = False
    self._error = None

    self.stop_event = asyncio.Event()
    self.executor = concurrent.futures.ThreadPoolExecutor()
    self.loop = asyncio.get_event_loop()

    def _shutdown_on(received):
        # Bind the signal now so each handler reports its own signal.
        def _handler():
            asyncio.create_task(self._shutdown(signal=received))
        return _handler

    # May want to catch other signals too
    for handled in (signal.SIGHUP, signal.SIGTERM, signal.SIGINT, signal.SIGQUIT):
        self.loop.add_signal_handler(handled, _shutdown_on(handled))

    # comment out the line below to see how unhandled exceptions behave
    self.loop.set_exception_handler(self.handle_exception)
class AppService:
    """Route events from the agent's general event stream to registered handlers.

    :meth:`run` reads events from the stream and puts them on a bounded queue;
    a background thread takes them off one at a time and calls the function
    registered (via :meth:`add_route`) for the event's ``"command"`` value.
    """

    NETW_CHUNK_SIZE = 1048576
    # Maximum number of in-flight requests to avoid exhausting server memory.
    QUEUE_MAX_SIZE = 2000

    def __init__(self, logger, task_config):
        """Create the agent API client and the processing queue.

        :param logger: logger used by the service and handed to ``AgentAPI``
        :param task_config: mapping providing ``'server_address'``,
            ``'agent_token'`` and ``'task_id'``
        """
        self.logger = logger
        self.server_address = task_config['server_address']
        self.api = AgentAPI(token=task_config['agent_token'],
                            server_address=self.server_address,
                            ext_logger=self.logger)
        self.api.add_to_metadata('x-task-id', str(task_config['task_id']))

        self.routes = {}
        # Not used by the methods of this class.
        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=10)
        self.processing_queue = Queue(maxsize=self.QUEUE_MAX_SIZE)
        # The agent token is a credential and must not end up in the log record.
        self.logger.debug('Created AgentRPCServicer',
                          extra=self._config_for_log(task_config))

    @staticmethod
    def _config_for_log(task_config):
        """Return a copy of ``task_config`` without the agent token."""
        return {key: value for key, value in task_config.items() if key != 'agent_token'}

    def add_route(self, route, func):
        """Register ``func`` as the handler for events whose command is ``route``."""
        self.routes[route] = func

    def _processing(self):
        """Handle queued events forever, one at a time.

        A failing handler (including an unknown or missing command) is logged
        and does not stop the loop.
        """
        while True:
            request_msg = self.processing_queue.get(block=True, timeout=None)
            try:
                handler = self.routes[request_msg["command"]]
                handler(request_msg)
            except Exception as e:
                self.logger.error(traceback.format_exc(), exc_info=True,
                                  extra={'exc_str': str(e)})

    def run(self):
        """Start the processing thread and feed it from the agent event stream.

        :raises ConnectionClosedByServerException: when the event stream ends
        """
        def seq_inf_wrapped():
            function_wrapper(self._processing)  # exit if raised

        process_thread = threading.Thread(target=seq_inf_wrapped, daemon=True)
        process_thread.start()

        event_stream = self.api.get_endless_stream('GetGeneralEventsStream',
                                                   api_proto.GeneralEvent,
                                                   api_proto.Empty())
        for gen_event in event_stream:
            try:
                data = {}
                if gen_event.data is not None and gen_event.data != b'':
                    data = json.loads(gen_event.data.decode('utf-8'))

                event_obj = {REQUEST_ID: gen_event.request_id, **data}
                # Blocks while the queue is full.
                self.processing_queue.put(event_obj, block=True)
            except Exception as error:
                # A malformed event is logged and dropped; the stream keeps going.
                self.logger.warning('App exception: ', extra={"error_message": str(error)})

        raise ConnectionClosedByServerException(
            'Requests stream to a deployed model closed by the server.')
def __init__(self, logger, task_config):
    """Create the agent API client and the processing queue.

    :param logger: logger used by the service and handed to ``AgentAPI``
    :param task_config: mapping providing ``'server_address'``,
        ``'agent_token'`` and ``'task_id'``
    """
    self.logger = logger
    self.server_address = task_config['server_address']
    self.api = AgentAPI(token=task_config['agent_token'],
                        server_address=self.server_address,
                        ext_logger=self.logger)
    self.api.add_to_metadata('x-task-id', str(task_config['task_id']))

    self.routes = {}
    self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=10)
    self.processing_queue = Queue(maxsize=self.QUEUE_MAX_SIZE)

    # The agent token is a credential and must not end up in the log record.
    loggable_config = {key: value for key, value in task_config.items()
                       if key != 'agent_token'}
    self.logger.debug('Created AgentRPCServicer', extra=loggable_config)
def __init__(self, logger, model_applier: SingleImageInferenceInterface, conn_config, cache):
    """Create the agent API client, worker pool and request queues.

    :param logger: logger used by the servicer and handed to ``AgentAPI``
    :param model_applier: model wrapper stored as ``self.model_applier``
    :param conn_config: mapping providing ``'token'``, ``'server_address'``
        and ``'task_id'``
    :param cache: image cache stored as ``self.image_cache``
    """
    self.logger = logger
    self.api = AgentAPI(token=conn_config['token'],
                        server_address=conn_config['server_address'],
                        ext_logger=self.logger)
    # Stringified like the other constructors in this module do for 'x-task-id'.
    self.api.add_to_metadata('x-task-id', str(conn_config['task_id']))

    self.model_applier = model_applier
    self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=10)
    self.download_queue = Queue(maxsize=self.QUEUE_MAX_SIZE)
    self.final_processing_queue = Queue(maxsize=self.QUEUE_MAX_SIZE)
    self.image_cache = cache
    self._default_inference_mode_config = InfModeFullImage.make_default_config(
        model_result_suffix=MODEL_RESULT_SUFFIX)

    # The token is a credential and must not end up in the log record.
    loggable_config = {key: value for key, value in conn_config.items() if key != 'token'}
    self.logger.info('Created AgentRPCServicer', extra=loggable_config)
class AgentRPCServicerBase:
    """Base class serving inference requests received from the agent event stream.

    Requests flow through two bounded queues: ``download_queue`` (input data is
    fetched on a loader thread) and ``final_processing_queue`` (requests are
    processed one at a time on a second thread).  Responses are sent from a
    thread pool.  Subclasses implement :meth:`_do_single_img_inference` and
    :meth:`_get_out_meta`.
    """

    NETW_CHUNK_SIZE = 1048576
    # Maximum number of in-flight requests to avoid exhausting server memory.
    QUEUE_MAX_SIZE = 2000

    # The annotation is quoted so it is not evaluated when the class is defined.
    def __init__(self, logger, model_applier: 'SingleImageInferenceInterface', conn_config, cache):
        """Create the agent API client, worker pool and request queues.

        :param logger: logger used by the servicer and handed to ``AgentAPI``
        :param model_applier: model wrapper stored as ``self.model_applier``
        :param conn_config: mapping providing ``'token'``, ``'server_address'``
            and ``'task_id'``
        :param cache: image cache stored as ``self.image_cache``
        """
        self.logger = logger
        self.api = AgentAPI(token=conn_config['token'],
                            server_address=conn_config['server_address'],
                            ext_logger=self.logger)
        # Stringified like the other constructors in this module do for 'x-task-id'.
        self.api.add_to_metadata('x-task-id', str(conn_config['task_id']))

        self.model_applier = model_applier
        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=10)
        self.download_queue = Queue(maxsize=self.QUEUE_MAX_SIZE)
        self.final_processing_queue = Queue(maxsize=self.QUEUE_MAX_SIZE)
        self.image_cache = cache
        self._default_inference_mode_config = InfModeFullImage.make_default_config(
            model_result_suffix=MODEL_RESULT_SUFFIX)

        # The token is a credential and must not end up in the log record.
        loggable_config = {key: value for key, value in conn_config.items() if key != 'token'}
        self.logger.info('Created AgentRPCServicer', extra=loggable_config)

    def _load_image_from_sly(self, req_id, image_hash, src_node_token):
        """Return the image for ``image_hash``, downloading and caching it on a miss."""
        self.logger.trace('Will look for image.', extra={
            'request_id': req_id,
            'image_hash': image_hash,
            'src_node_token': src_node_token
        })
        img_data = self.image_cache.get(image_hash)
        if img_data is None:
            img_data_packed = download_image_from_remote(
                self.api, image_hash, src_node_token, self.logger)
            img_data = decode_image(img_data_packed)
            self.image_cache.add(image_hash, img_data)
        return img_data

    def _load_arbitrary_image(self, req_id):
        """Download and decode the image data attached to request ``req_id``."""
        self.logger.trace('Will load arbitrary image.', extra={'request_id': req_id})
        img_data_packed = download_data_from_remote(self.api, req_id, self.logger)
        return decode_image(img_data_packed)

    def _load_data_if_required(self, event_obj):
        """Fetch the input image for an inference request and queue it for processing.

        Requests of other types are queued unchanged.  On failure an error
        response is sent for the request, provided its id is known.
        """
        # Read the id before entering ``try`` so the error handler can always
        # refer to it, even when the event itself is malformed.
        req_id = event_obj.get('request_id')
        try:
            if req_id is None:
                raise KeyError('request_id')
            event_data = event_obj[DATA]
            request_type = event_data.get(REQUEST_TYPE, INFERENCE)
            if request_type == INFERENCE:
                # For inference we need to download an image and add it to the event data.
                image_hash = event_data.get('image_hash')
                if image_hash is None:
                    img_data = self._load_arbitrary_image(req_id)
                else:
                    src_node_token = event_obj['data'].get('src_node_token', '')
                    img_data = self._load_image_from_sly(req_id, image_hash, src_node_token)
                event_data['image_arr'] = img_data
                self.logger.trace('Input image obtained.', extra={'request_id': req_id})
            self.final_processing_queue.put(item=(event_data, req_id))
        except Exception as e:
            self.logger.error(traceback.format_exc(), exc_info=True, extra={'exc_str': str(e)})
            if req_id is None:
                # Nobody to answer to; keep the loader thread alive.
                self.logger.warning('Request has no id, error response is not sent.')
                return
            res_msg = {
                'success': False,
                'error': json.dumps(traceback.format_exc())
            }
            self.thread_pool.submit(function_wrapper_nofail,
                                    self._send_data, res_msg, req_id)  # skip errors

    def _send_data(self, out_msg, req_id):
        """Serialize ``out_msg`` to JSON and stream it back for request ``req_id``."""
        self.logger.trace('Will send output data.', extra={'request_id': req_id})
        out_bytes = json.dumps(out_msg).encode('utf-8')

        self.api.put_stream_with_data('SendGeneralEventData',
                                      api_proto.Empty,
                                      send_from_memory_generator(out_bytes, self.NETW_CHUNK_SIZE),
                                      addit_headers={'x-request-id': req_id})
        self.logger.trace('Output data is sent.', extra={'request_id': req_id})

    def _final_processing(self, in_msg):
        """Produce the response payload for one request.

        :raises RuntimeError: for an unknown request type, or an inference
            image that is not a 3-dimensional array with 3 or 4 channels
        """
        request_type = in_msg.get(REQUEST_TYPE, INFERENCE)

        if request_type == INFERENCE:
            img = in_msg['image_arr']
            if len(img.shape) != 3 or img.shape[2] not in [3, 4]:
                raise RuntimeError('Expect 3- or 4-channel image RGB(RGBA) [0..255].')
            if img.shape[2] == 4:
                img = drop_image_alpha_channel(img)
            return self._do_single_img_inference(img, in_msg)

        if request_type == GET_OUT_META:
            return {'out_meta': self._get_out_meta(in_msg).to_json()}

        raise RuntimeError(
            'Unknown request type: {}. Only the following request types are supported: {}'
            .format(request_type, SUPPORTED_REQUEST_TYPES))

    def _do_single_img_inference(self, img, in_msg):
        """Run inference on one image; must be implemented by subclasses."""
        raise NotImplementedError()

    def _get_out_meta(self, in_msg):
        """Return the output meta object; must be implemented by subclasses."""
        raise NotImplementedError()

    def _sequential_final_processing(self):
        """Process queued requests forever, one at a time, and send each response."""
        while True:
            in_msg, req_id = self.final_processing_queue.get(block=True, timeout=None)
            res_msg = {}
            try:
                res_msg.update(self._final_processing(in_msg))
                res_msg.update({'success': True})
            except Exception as e:
                self.logger.error(traceback.format_exc(), exc_info=True,
                                  extra={'exc_str': str(e)})
                res_msg.update({
                    'success': False,
                    'error': json.dumps(traceback.format_exc())
                })

            self.thread_pool.submit(function_wrapper_nofail,
                                    self._send_data, res_msg, req_id)  # skip errors

    def _load_data_loop(self):
        """Take events from the download queue forever and load their input data."""
        while True:
            event_obj = self.download_queue.get(block=True, timeout=None)
            self._load_data_if_required(event_obj)

    def run_inf_loop(self):
        """Start the worker threads and feed them from the agent event stream."""
        def seq_inf_wrapped():
            function_wrapper(self._sequential_final_processing)  # exit if raised

        load_data_thread = threading.Thread(target=self._load_data_loop, daemon=True)
        load_data_thread.start()
        inference_thread = threading.Thread(target=seq_inf_wrapped, daemon=True)
        inference_thread.start()
        report_agent_rpc_ready()

        event_stream = self.api.get_endless_stream('GetGeneralEventsStream',
                                                   api_proto.GeneralEvent,
                                                   api_proto.Empty())
        for gen_event in event_stream:
            # Reset per event so a failure never reuses the previous event's id.
            request_id = None
            try:
                request_id = gen_event.request_id
                data = {}
                if gen_event.data is not None and gen_event.data != b'':
                    data = json.loads(gen_event.data.decode('utf-8'))

                event_obj = {'request_id': request_id, 'data': data}
                self.logger.debug('GET_INFERENCE_CALL', extra=event_obj)
                # Blocks while the download queue is full.
                self.download_queue.put(event_obj, block=True)
            except Exception as error:
                self.logger.warning('Inference exception: ',
                                    extra={"error_message": str(error)})
                if request_id is not None:
                    res_msg = {'success': False, 'error': json.dumps(str(error))}
                    self.thread_pool.submit(function_wrapper_nofail,
                                            self._send_data, res_msg, request_id)
class AppService:
    """Application session service built on an asyncio event loop.

    A publisher reads command events from the agent's general event stream
    into a queue; a consumer takes them off the queue and runs the callback
    registered for each command (see :meth:`callback`) in a thread pool.
    """

    NETW_CHUNK_SIZE = 1048576
    # Maximum number of in-flight requests to avoid exhausting server memory.
    # NOTE: not applied at the moment -- ``__init__`` creates an unbounded queue.
    QUEUE_MAX_SIZE = 2000
    DEFAULT_EVENTS = [STOP_COMMAND, *IMAGE_ANNOTATION_EVENTS]

    def __init__(self, logger=None, task_id=None, server_address=None, agent_token=None,
                 ignore_errors=False, ignore_task_id=False):
        """Prepare settings, API clients, the work queue and the event loop.

        ``logger``, ``task_id``, ``server_address`` and ``agent_token`` are each
        combined with a fallback through ``take_with_default``.  The
        environment fallbacks are looked up eagerly as call arguments, so a
        missing variable raises ``KeyError`` even when a value is supplied.

        :param logger: logger for the app; fallback is ``default_logger``
        :param task_id: task id; fallback is ``os.environ["TASK_ID"]``
        :param server_address: fallback is ``os.environ[SERVER_ADDRESS]``
        :param agent_token: fallback is ``os.environ[AGENT_TOKEN]``
        :param ignore_errors: when ``False``, an error in a callback shuts the
            app down (see :meth:`handle_message_sync`)
        :param ignore_task_id: forwarded to the public ``Api`` clients
        """
        self._ignore_task_id = ignore_task_id
        self.logger = take_with_default(logger, default_logger)
        self._ignore_errors = ignore_errors
        self.task_id = take_with_default(task_id, os.environ["TASK_ID"])
        self.server_address = take_with_default(server_address, os.environ[SERVER_ADDRESS])
        self.agent_token = take_with_default(agent_token, os.environ[AGENT_TOKEN])
        self.public_api = Api.from_env(ignore_task_id=self._ignore_task_id)
        self._app_url = self.public_api.app.get_url(self.task_id)
        self._session_dir = "/sessions/{}".format(self.task_id)

        self.api = AgentAPI(token=self.agent_token, server_address=self.server_address,
                            ext_logger=self.logger)
        self.api.add_to_metadata('x-task-id', str(self.task_id))

        self.callbacks = {}
        self.processing_queue = queue.Queue()  # (maxsize=self.QUEUE_MAX_SIZE)
        self.logger.debug('App is created',
                          extra={"task_id": self.task_id, "server_address": self.server_address})

        self._ignore_stop_for_debug = False
        self._error = None
        self.stop_event = asyncio.Event()

        self.executor = concurrent.futures.ThreadPoolExecutor()
        self.loop = asyncio.get_event_loop()
        # May want to catch other signals too
        signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT, signal.SIGQUIT)
        for s in signals:
            # ``s=s`` binds the current signal; a plain closure would see the last one.
            self.loop.add_signal_handler(
                s, lambda s=s: asyncio.create_task(self._shutdown(signal=s)))
        # comment out the line below to see how unhandled exceptions behave
        self.loop.set_exception_handler(self.handle_exception)

    def handle_exception(self, loop, context):
        """Event-loop exception handler: log the problem and start shutdown.

        :param loop: the event loop reporting the problem (unused)
        :param context: asyncio context dict
        """
        # context["message"] will always be there; "exception" and "future" may not
        msg = context.get("exception", context["message"])
        if isinstance(msg, Exception):
            # Pass the instance: there is no active ``except`` block here, so
            # ``exc_info=True`` would not attach any traceback.
            self.logger.error(msg, exc_info=msg, extra={'future_info': context.get("future")})
        else:
            self.logger.error("Caught exception: {}".format(msg))

        self.logger.info("Shutting down...")
        asyncio.create_task(self._shutdown())

    @property
    def session_dir(self):
        """Session directory, ``/sessions/<task_id>``."""
        return self._session_dir

    @property
    def repo_dir(self):
        """``repo`` sub-directory of the session directory."""
        return os.path.join(self._session_dir, "repo")

    @property
    def data_dir(self):
        """``data`` sub-directory of the session directory."""
        return os.path.join(self._session_dir, "data")

    @property
    def app_url(self):
        """URL of the app as returned by the public API at construction time."""
        return self._app_url

    def _add_callback(self, callback_name, func):
        """Register ``func`` as the handler for command ``callback_name``."""
        self.callbacks[callback_name] = func

    def callback(self, callback_name):
        """A decorator that is used to register a function for a given
        application command::

            @app.callback('calc')
            def calc_func():
                return 'Hello World'

        The undecorated function is what gets registered; the decorated name
        stays callable and returns whatever the function returns.

        :param callback_name: the command name as string
        """
        def decorator(f):
            self._add_callback(callback_name, f)

            @functools.wraps(f)
            def wrapper(*args, **kwargs):
                # Propagate the result instead of silently dropping it.
                return f(*args, **kwargs)
            return wrapper
        return decorator

    def handle_message_sync(self, request_msg):
        """Run the callback registered for one command message.

        Called in an executor thread.  A ``KeyError`` (unknown command, missing
        message field, or one raised by the callback itself) is logged without
        a traceback and does not stop the app.  Any other exception is logged
        and, unless ``ignore_errors`` was set, triggers shutdown.

        :param request_msg: message dict with ``"command"`` and ``"api_token"``
        """
        try:
            state = request_msg.get(STATE, None)
            context = request_msg.get(CONTEXT, None)
            command = request_msg["command"]
            user_api_token = request_msg["api_token"]
            user_public_api = Api(self.server_address, user_api_token, retry_count=5,
                                  external_logger=self.logger,
                                  ignore_task_id=self._ignore_task_id)

            if command == STOP_COMMAND:
                self.logger.info("APP receives stop signal from user")
                self.stop_event.set()

            if command == STOP_COMMAND and command not in self.callbacks:
                # No user-defined stop callback: run the default one.
                _default_stop(user_public_api, self.task_id, context, state, self.logger)
                if self._ignore_stop_for_debug is False:
                    asyncio.run_coroutine_threadsafe(self._shutdown(), self.loop)
                    return
                else:
                    self.logger.info("STOP event is ignored ...")
            elif command in AppService.DEFAULT_EVENTS and command not in self.callbacks:
                raise KeyError("App received default command {!r}. Use decorator \"callback\" to handle it."
                               .format(command))
            elif command not in self.callbacks:
                raise KeyError("App received unhandled command {!r}. Use decorator \"callback\" to handle it."
                               .format(command))

            if command == STOP_COMMAND:
                if self._ignore_stop_for_debug is False:
                    self.callbacks[command](api=user_public_api,
                                            task_id=self.task_id,
                                            context=context,
                                            state=state,
                                            app_logger=self.logger)
                    asyncio.run_coroutine_threadsafe(self._shutdown(), self.loop)
                    return
                else:
                    self.logger.info("STOP event is ignored ...")
            else:
                self.callbacks[command](api=user_public_api,
                                        task_id=self.task_id,
                                        context=context,
                                        state=state,
                                        app_logger=self.logger)
        except KeyError as e:
            self.logger.error(e, exc_info=False)
        except Exception as e:
            self.logger.error(traceback.format_exc(), exc_info=True, extra={'exc_str': repr(e)})
            if self._ignore_errors is False:
                self.logger.info("App will be stopped due to error")
                # We are in a worker thread, so the coroutine is handed to the loop.
                asyncio.run_coroutine_threadsafe(self._shutdown(error=e), self.loop)

    def consume_sync(self):
        """Take messages off the queue forever and hand each to the executor."""
        while True:
            request_msg = self.processing_queue.get()
            to_log = _remove_sensitive_information(request_msg)
            self.logger.debug('FULL_TASK_MESSAGE', extra={'task_msg': to_log})
            asyncio.ensure_future(
                self.loop.run_in_executor(self.executor, self.handle_message_sync, request_msg),
                loop=self.loop
            )

    async def consume(self):
        """Start :meth:`consume_sync` in the executor."""
        self.logger.info("Starting consumer")
        asyncio.ensure_future(
            self.loop.run_in_executor(self.executor, self.consume_sync),
            loop=self.loop
        )

    def publish_sync(self, initial_events=None):
        """Queue ``initial_events`` and then every event from the agent stream.

        :param initial_events: optional iterable of event dicts; each gets
            ``os.environ[API_TOKEN]`` as its ``"api_token"`` before queuing
        :raises ConnectionClosedByServerException: when the event stream ends
        """
        if initial_events is not None:
            for event_obj in initial_events:
                event_obj["api_token"] = os.environ[API_TOKEN]
                self.processing_queue.put(event_obj)

        for gen_event in self.api.get_endless_stream('GetGeneralEventsStream',
                                                     api_proto.GeneralEvent,
                                                     api_proto.Empty()):
            try:
                data = {}
                if gen_event.data is not None and gen_event.data != b'':
                    data = json.loads(gen_event.data.decode('utf-8'))

                event_obj = {REQUEST_ID: gen_event.request_id, **data}
                self.processing_queue.put(event_obj)
            except Exception as error:
                # A malformed event is logged and dropped; the stream keeps going.
                self.logger.warning('App exception: ', extra={"error_message": repr(error)})

        raise ConnectionClosedByServerException(
            'Requests stream to a deployed model closed by the server.')

    async def publish(self, initial_events=None):
        """Start :meth:`publish_sync` in the executor."""
        self.logger.info("Starting publisher")
        asyncio.ensure_future(
            self.loop.run_in_executor(self.executor, self.publish_sync, initial_events),
            loop=self.loop
        )

    def run(self, template_path=None, data=None, state=None, initial_events=None):
        """Initialize the app session and run the event loop until shutdown.

        The GUI template is taken from ``template_path``; else from the
        ``gui_template`` entry of ``config.json`` in the repo directory; else
        from ``gui.html`` next to the running script.  A missing template file
        means the app runs without a GUI.

        :param template_path: optional path to the GUI template
        :param data: passed to ``public_api.app.initialize``
        :param state: passed to ``public_api.app.initialize``
        :param initial_events: events queued before the agent stream is read
        :raises Exception: re-raises the error recorded by ``_shutdown``, if any
        """
        if template_path is None:
            # read config
            config_path = os.path.join(self.repo_dir, os.environ.get("CONFIG_DIR", ""),
                                       'config.json')
            if file_exists(config_path):
                # we are not in debug mode
                config = load_json_file(config_path)
                template_path = config.get('gui_template', None)
                if template_path is None:
                    self.logger.info("there is no gui_template field in config.json")
                else:
                    template_path = os.path.join(self.repo_dir, template_path)

        if template_path is None:
            template_path = os.path.join(os.path.dirname(sys.argv[0]), 'gui.html')

        if not file_exists(template_path):
            self.logger.info("App will be running without GUI", extra={"app_url": self.app_url})
            template = ""
        else:
            with open(template_path, 'r') as file:
                template = file.read()

        self.public_api.app.initialize(self.task_id, template, data, state)
        self.logger.info("Application session is initialized", extra={"app_url": self.app_url})

        try:
            self.loop.create_task(self.publish(initial_events), name="Publisher")
            self.loop.create_task(self.consume(), name="Consumer")
            self.loop.run_forever()
        finally:
            self.loop.close()
            self.logger.info("Successfully shutdown the APP service.")

        if self._error is not None:
            raise self._error

    def stop(self, wait=True):
        """Stop the app.

        :param wait: when ``True``, queue a ``"stop"`` command so it is handled
            after the messages already queued; otherwise shut down right away
        """
        # @TODO: add timeout
        if wait is True:
            event_obj = {"command": "stop", "api_token": os.environ[API_TOKEN]}
            self.processing_queue.put(event_obj)
        else:
            self.logger.info('Stop app (force, no wait)',
                             extra={'event_type': EventType.APP_FINISHED})
            asyncio.run_coroutine_threadsafe(self._shutdown(), self.loop)

    async def _shutdown(self, signal=None, error=None):
        """Cleanup tasks tied to the service's shutdown.

        Cancels all other tasks, shuts the executor down without waiting and
        stops the loop.

        :param signal: the signal that triggered shutdown, if any (note: this
            parameter shadows the ``signal`` module inside the method)
        :param error: exception to be re-raised by :meth:`run` after the loop stops
        """
        # Record the error first so it survives a failure or cancellation below.
        if error is not None:
            self._error = error

        if signal:
            self.logger.info(f"Received exit signal {signal.name}...")
        self.logger.info("Nacking outstanding messages")
        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
        for task in tasks:
            task.cancel()

        self.logger.info(f"Cancelling {len(tasks)} outstanding tasks")
        await asyncio.gather(*tasks, return_exceptions=True)

        self.logger.info("Shutting down ThreadPoolExecutor")
        self.executor.shutdown(wait=False)

        self.logger.info(f"Releasing {len(self.executor._threads)} threads from executor")
        for thread in self.executor._threads:
            # Best effort: pokes at private threading internals, which may be
            # absent or already released, so failures are deliberately ignored.
            try:
                thread._tstate_lock.release()
            except Exception:
                pass

        self.logger.info("Flushing metrics")
        self.loop.stop()