def start_game():
    """Main entry point for the application.

    Binds the HTTP server to port 8888, schedules the scoring and session
    clean-up callbacks (only when ``process.task_id()`` is ``None``, i.e. the
    process was not forked), logs a three second countdown and then runs the
    IOLoop until a ``KeyboardInterrupt`` arrives.
    """
    cache_actions()
    sockets = netutil.bind_sockets(8888)
    # Forking is currently disabled:
    # if process.task_id() is None:
    #     tornado.process.fork_processes(-1, max_restarts=10)
    server = HTTPServer(application)
    server.add_sockets(sockets)
    io_loop = IOLoop.instance()
    session_manager = SessionManager.Instance()

    # Track the callbacks this process actually started so that shutdown
    # stops exactly those (previously ``scoring`` was never stopped and
    # ``session_clean_up`` could be referenced without being defined).
    periodic_callbacks = []
    if process.task_id() is None:
        scoring = PeriodicCallback(
            scoring_round,
            application.settings['ticks'],
            io_loop=io_loop)
        session_clean_up = PeriodicCallback(
            session_manager.clean_up,
            application.settings['clean_up_timeout'],
            io_loop=io_loop)
        periodic_callbacks.extend((scoring, session_clean_up))
        for callback in periodic_callbacks:
            callback.start()

    try:
        for count in range(3, 0, -1):
            logging.info("The game will begin in ... %d" % (count,))
            sleep(1)
        logging.info("Good hunting!")
        io_loop.start()
    except KeyboardInterrupt:
        # Only the un-forked process (None) or the first child (0) reports.
        if process.task_id() in (None, 0):
            print('\r[!] Shutdown Everything!')
        for callback in periodic_callbacks:
            callback.stop()
        io_loop.stop()
class WSHandler(tornado.websocket.WebSocketHandler):
    """WebSocket handler that forwards entries of ``q_live`` as JSON."""

    def check_origin(self, origin):
        """Accept connections regardless of their origin."""
        return True

    def _flush_queue(self):
        """Discard everything currently waiting in ``q_live``."""
        with q_live.mutex:
            q_live.queue.clear()

    def open(self):
        """Drop stale queue entries and start polling the queue every 1 ms."""
        self._flush_queue()
        self.callback = PeriodicCallback(self.send_werte, 1)
        self.callback.start()
        print('Connection open')

    def send_werte(self):
        """Send one queued ``(signals, values)`` pair as a JSON object.

        After sending, the queue is cleared when more than 15 entries are
        still waiting.
        """
        if q_live.empty():
            return
        names, readings = q_live.get()
        payload = dict(zip(names, readings))
        print(payload)
        self.write_message(json.dumps(payload))
        print(q_live.qsize())
        if q_live.qsize() > 15:
            self._flush_queue()

    def on_message(self, empf):
        """Log that data arrived; the content is ignored."""
        print('Daten recievied: ')

    def on_close(self):
        """Stop polling once the client disconnects."""
        print('Connection closed!')
        self.callback.stop()
class SendWebSocket(tornado.websocket.WebSocketHandler):
    """Pushes an incrementing counter to the client every millisecond.

    ``on_message`` receives data, ``write_message`` sends data.
    """

    def __init__(self, *args, **keys):
        # The counter is set before the base initialiser runs.
        self.i = 0
        super(SendWebSocket, self).__init__(*args, **keys)

    def open(self):
        """Start the periodic sender."""
        sender = PeriodicCallback(self._send_message, 1)
        self.callback = sender
        sender.start()
        print("WebSocket opend")

    def on_message(self, message):
        """Echo any received message to stdout."""
        print(message)

    def _send_message(self):
        """Send the next counter value, plus a newline after every 20th."""
        self.i += 1
        current = self.i
        self.write_message(str(current))
        if not current % 20:
            self.write_message("\n")

    def on_close(self):
        """Stop the periodic sender."""
        self.callback.stop()
        print("WebSocket closed")
class SocketHandler(WebSocketHandler):
    """Status-monitor socket that pushes ``data_json`` once per second."""

    def check_origin(self, origin):
        """
        Overrides the parent method to return True for any request, since we
        are working without names.

        :returns: bool True
        """
        return True

    def open(self):
        """Register this socket and start its one-second push timer."""
        logging.info("Connection open from " + self.request.remote_ip)
        if self not in statusmonitor_open_sockets:
            statusmonitor_open_sockets.append(self)
        # Per-connection timer, see http://stackoverflow.com/a/19571205
        pusher = PeriodicCallback(self.send_data, 1000)
        self.callback = pusher
        pusher.start()
        start_callback()

    def send_data(self):
        """Send the current module-level ``data_json`` to the client."""
        self.write_message(data_json)

    def on_close(self):
        """Stop the push timer and unregister this socket."""
        self.callback.stop()
        if self in statusmonitor_open_sockets:
            statusmonitor_open_sockets.remove(self)
        stop_callback()

    def send_update(self):
        """Placeholder; intentionally does nothing."""
        pass
class WebSocketconnectionsHandler(tornado.websocket.WebSocketHandler):
    """Reports the number of inet connections (via psutil) over a WebSocket.

    The count is sent when the socket opens, on every incoming message, and
    whenever the once-per-second poll detects a change.
    """

    def __init__(self, *args, **kwargs):
        logger.debug("Creating WebSocket connections handler")
        super(WebSocketconnectionsHandler, self).__init__(*args, **kwargs)
        # No WebSocket connection yet
        self.connected = False
        # We have not counted the connections yet
        self.connections = 0
        # Update the connection count
        self.update()
        # Setup periodic callback via Tornado
        self.periodic_callback = PeriodicCallback(self.update, 1000)

    def get_connections(self):
        """Recount the connections, store the result and return it.

        Counts every inet connection when the first configured port is
        ``'all'``; otherwise counts connections whose local port matches one
        of the configured ports (once per matching configuration entry).
        """
        conn = psutil.net_connections('inet')
        ports = ws.config.CONFIG['PORT']
        if ports[0] == 'all':
            # If we need the count for all ports we've got it.
            self.connections = len(conn)
        else:
            # Isolate data for the requested ports.
            wanted = [int(port) for port in ports]
            self.connections = sum(
                1
                for port in wanted
                for connection in conn
                if connection.laddr[1] == port)
        return self.connections

    def _send_count(self, count):
        """Log and send ``count`` as a ``{"connections": ...}`` JSON message."""
        payload = json.dumps({"connections": count})
        logger.debug(payload)
        self.write_message(payload)

    def update(self):
        """Poll the connection count and push it when it has changed."""
        # Save the old number of connections
        old = self.connections
        current = self.get_connections()
        # Only send when the count changed and a client is attached. The
        # value just measured is reused so the log and the message agree.
        if old != current and self.connected:
            self._send_count(current)

    def open(self):
        """Send the current count and start the periodic poll."""
        self._send_count(self.get_connections())
        # We have a WebSocket connection
        self.connected = True
        self.periodic_callback.start()

    def on_message(self, message):
        """Answer any incoming message with the current count."""
        self._send_count(self.get_connections())

    def on_close(self):
        """Stop polling once the client has gone."""
        logger.debug("Connection closed")
        # We no longer have a WebSocket connection.
        self.connected = False
        self.periodic_callback.stop()
class LoLAPI(object):
    """Per-client API object that pushes database statistics every second."""

    def __init__(self, client):
        self.timer = PeriodicCallback(self.status, 1000, IOLoop.instance())
        self.client = client
        self.timer.start()

    def status(self):
        """Send a timestamp and the current collection counts to the client."""
        snapshot = dict(
            last_updated=datetime.now().strftime("%H:%M:%S %d-%m-%y"),
            game_stats=db.games_data.count(),
            players=db.users.count(),
            full_games=db.games.count(),
            invalid_games=db.invalid_games.count(),
        )
        self.client.one.update_status(snapshot)

    def set_user(self, name):
        """Select the user called ``name`` and push a game list to the client.

        The looked-up games are not sent yet; a fixed placeholder list is.
        """
        self.user = User.by_name(name)
        query = dict(summoner=self.user.get_dbref())
        stats = GameStats.find(query)
        games = []
        for stat in stats:
            games.append(Game.find_one(stat['game_id']))
        self.client.one.update_games([1, 2, 3, 4, 5, 6, 7])
        # self.client.one.update_games(list(stats))

    def detach(self):
        """Stop the periodic status push."""
        self.timer.stop()
class WebSocketChatHandler(tornado.websocket.WebSocketHandler):
    """Chat bridge between web clients and the game server.

    Relies on ``messages``, ``messages_log``, ``player_manager`` and
    ``factory`` being provided on the handler from outside this class.
    """

    def initialize(self):
        """Set up per-handler state and resolve the player from the cookie."""
        self.clients = []
        self.callback = PeriodicCallback(self.update_chat, 500)
        player_name = self.get_secure_cookie("player")
        self.web_gui_user = self.player_manager.get_by_name(player_name)

    def open(self, *args):
        """Register the client, replay the message log and start flushing."""
        self.clients.append(self)
        for logged in self.messages_log:
            self.write_message(logged)
        self.callback.start()

    def on_message(self, message):
        """Queue the raw message and broadcast its text to the game."""
        messagejson = json.loads(message)
        self.messages.append(message)
        self.messages_log.append(message)
        text = "^yellow;<{d}> <^red;{u}^yellow;> {m}".format(
            d=datetime.now().strftime("%H:%M"),
            u=self.web_gui_user.name,
            m=messagejson["message"])
        self.factory.broadcast(text, 0, "")

    def update_chat(self):
        """Send all pending messages (sorted) to every client, then clear them."""
        if not self.messages:
            return
        for pending in sorted(self.messages):
            for client in self.clients:
                client.write_message(pending)
        # Clear in place so other holders of the list see it emptied.
        del self.messages[:]

    def on_close(self):
        """Unregister the client and stop the flush timer."""
        self.clients.remove(self)
        self.callback.stop()
class cpustatus(tornado.websocket.WebSocketHandler):
    """Echoes the last integer received from the client whenever it changes.

    ``on_message`` receives data, ``write_message`` sends data (index.html).
    """

    def open(self):
        """Reset the tracked values and start the 500 ms change check."""
        # self.i = readData()
        self.i = 0
        self.last = 0
        self.cpu = PeriodicCallback(self._send_cpu, 500)
        self.cpu.start()

    def on_message(self, message):
        """Parse ``message`` as an int and publish it as ``MainMotorMax``."""
        global MainMotorMax
        value = int(message)
        self.i = value
        MainMotorMax = value
        print(message)

    def _send_cpu(self):
        """Send the current value if it differs from the last one sent."""
        # self.write_message(str(vmstat()[15]))
        # self.write_message(str(time.time()))
        # self.i = readData()
        if self.i == self.last:
            return
        self.write_message(str(self.i))
        self.last = self.i
        print(self.i)

    def on_close(self):
        """Stop the change check."""
        self.cpu.stop()
class WSHandler(tornado.websocket.WebSocketHandler):
    """Demo handler that keeps two rolling series of random values."""

    def initialize(self):
        """Reset both value series to empty."""
        self.values = [[], []]

    def check_origin(self, origin):
        """Accept connections from any origin."""
        return True

    def open(self):
        """Reset the series and start sending every ``timeInterval`` ms."""
        self.initialize()
        ticker = PeriodicCallback(self.send_values, timeInterval)
        self.callback = ticker
        ticker.start()

    def send_values(self):
        """Advance both series by one random point and send a message.

        Each series is capped at 30 points (the oldest is dropped first).
        The channel payload is built but a fixed ``DataInfo`` message is what
        actually gets sent.
        """
        MAX_POINTS = 30
        for series in self.values:
            if len(series) >= MAX_POINTS:
                series.pop(0)
            series.append(randint(1, 10))
        # self.values1 = [randint(1,10) for i in range(100)]
        message = {"Channel0": self.values[0], "Channel1": self.values[1]}
        # self.write_message(message)
        message = {"DataInfo": [{"id": 40, "sname": "SOG"}]}
        self.write_message(message)

    def on_message(self, message):
        """Incoming messages are ignored."""
        pass

    def on_close(self):
        """Stop the periodic sender."""
        self.callback.stop()
class EventedStatsCollector(StatsCollector):
    """
    Stats Collector which allows to subscribe to value changes.
    Update notifications are throttled: interval between updates is no shorter
    than ``accumulate_time``.

    It is assumed that stat keys are never deleted.
    """
    accumulate_time = 0.1  # value is in seconds

    def __init__(self, crawler):
        super(EventedStatsCollector, self).__init__(crawler)
        self.signals = SignalManager(self)
        self._changes = {}
        self._task = PeriodicCallback(self.emit_changes,
                                      self.accumulate_time * 1000)
        self._task.start()

        # FIXME: this is ugly
        self.crawler = crawler  # used by ArachnadoCrawlerProcess

    def emit_changes(self):
        """Send the accumulated changes (if any) and start a fresh batch."""
        if self._changes:
            changes, self._changes = self._changes, {}
            self.signals.send_catch_log(stats_changed, changes=changes)

    def open_spider(self, spider):
        """Ensure the notification task runs while a spider is open."""
        super(EventedStatsCollector, self).open_spider(spider)
        # The task is already started by __init__; calling start() on a
        # running PeriodicCallback schedules a second timeout, so only
        # (re)start it when it was stopped by close_spider().
        if not self._task.is_running():
            self._task.start()

    def close_spider(self, spider, reason):
        """Stop the notification task when the spider closes."""
        super(EventedStatsCollector, self).close_spider(spider, reason)
        self._task.stop()
class WebSocket(tornado.websocket.WebSocketHandler):
    """Broadcast socket: every received message is relayed to all clients.

    A keep-alive message is also sent to each client every 30 seconds to
    prevent idle timeouts.
    """
    waiters = set()  # all currently connected clients
    wdata = ""  # most recently received message

    def open(self):
        """Register the client and start its 30 s keep-alive timer."""
        print("open websocket connection")
        WebSocket.waiters.add(self)
        self.callback = PeriodicCallback(self._send_message, 30000)
        self.callback.start()

    def on_close(self):
        """Unregister the client and stop its keep-alive timer."""
        # discard() so that a missing entry cannot abort the clean-up below.
        WebSocket.waiters.discard(self)
        self.callback.stop()
        print("close websocket connection")

    def on_message(self, message):
        """Remember the message and relay it to every connected client."""
        WebSocket.wdata = message
        WebSocket.send_updates(message)

    @classmethod
    def send_updates(cls, message):
        """Send ``message`` to every registered client.

        A failure for one client is logged and does not stop delivery to the
        remaining clients.
        """
        import logging

        print(message + ":connection=" + str(len(cls.waiters)))
        # Iterate over a snapshot: the set may change if a client closes.
        for waiter in list(cls.waiters):
            try:
                waiter.write_message(message)
            except Exception:
                # print() does not accept exc_info; the previous call raised
                # TypeError here instead of reporting the send failure.
                logging.error("Error sending message", exc_info=True)

    def _send_message(self):
        """Keep-alive callback, fired every 30 seconds."""
        self.write_message("C:POLLING")
class WebSocketGame(WebSocketHandler):
    """Base WebSocket handler for a game session.

    ``initialize_game``, ``handle_message`` and ``do_loop`` are not defined
    here and must be provided elsewhere (presumably by subclasses).
    """

    def open(self):
        """Initialise the game state and send it to the client."""
        self.game_data = {}
        self.initialize_game()
        self.write_message(self.game_data)

    def on_message(self, message):
        """Handle a JSON message: ``login`` is processed here, others delegated."""
        message = json.loads(message)
        if message["type"] == "login":
            self.game_name = message["name"]
            self.game_id = message["game_id"]
            # NOTE(review): the callback is created but never started in this
            # class -- confirm whether a start() call is missing.
            self.loop_callback = PeriodicCallback(self.do_loop, 5000)
        else:
            self.handle_message(message)

    def on_close(self):
        """Stop the game loop, if a login ever created one."""
        # ``loop_callback`` only exists after a login message; a client that
        # disconnects earlier must not trigger an AttributeError.
        loop_callback = getattr(self, "loop_callback", None)
        if loop_callback is not None:
            loop_callback.stop()

    def update_status(self, status):
        """Report the game status to the private API.

        :param status: one of ``"S"`` (Start), ``"I"`` (InProgress),
            ``"U"`` (Successful) or ``"F"`` (Fail); anything else is ignored.
        """
        if status not in ("S", "I", "U", "F"):
            return  # Let's try not to hit the status API with bad values.
        url = "http://localhost:8080/private_api/gametask/{}/{}/{}".format(
            self.game_name, self.game_id, status)
        request = HTTPRequest(url=url)
        http = AsyncHTTPClient()
        http.fetch(request, self.callback)

    def callback(self, response):
        """Print the HTTP status of the status-API response."""
        # Catch any errors.
        print("Callback fired.")
        print("HTTP Code: {}".format(response.code))
class SendWebSocketHandler(tornado.websocket.WebSocketHandler):
    """Pushes the newest ``lm35dz`` row to the client every 10 seconds.

    ``on_message`` receives data, ``write_message`` sends data.
    """

    def open(self):
        """Start the periodic sender."""
        sender = PeriodicCallback(self._send_message, 10000)
        self.callback = sender
        sender.start()
        print("[START] WebSocket")

    def on_message(self, message):
        """Print any received message."""
        print("[START] WebSocket on_message")
        print(message)

    def _send_message(self):
        """Send the first two columns of the newest row, or a placeholder."""
        cursor = DB.execute("SELECT * FROM lm35dz ORDER BY YMDHHMM DESC")
        newest = cursor.fetchone()
        if newest is None:
            text = "Data Nothing"
        else:
            text = "%s %s" % (newest[0], newest[1])
        self.write_message(text)

    def on_close(self):
        """Stop the periodic sender."""
        self.callback.stop()
        print("[ENDED] WebSocket")
class WebSocket(tornado.websocket.WebSocketHandler):
    """Streams base64-encoded JPEG camera frames to the client on request."""

    def check_origin(self, origin):
        """Accept connections from any origin."""
        return True

    def on_message(self, message):
        """Evaluates the function pointed to by json-rpc."""
        # Start an infinite loop when this is called
        if message == "read_camera":
            # A repeated request must replace the running loop; otherwise
            # every request adds another 10 ms capture loop on this socket.
            self._stop_camera_loop()
            self.camera_loop = PeriodicCallback(self.loop, 10)
            self.camera_loop.start()
        # Extensibility for other methods
        else:
            print("Unsupported function: " + message)

    def on_close(self):
        """Stop capturing as soon as the client disconnects."""
        self._stop_camera_loop()

    def _stop_camera_loop(self):
        """Stop the capture loop if one has been started."""
        camera_loop = getattr(self, "camera_loop", None)
        if camera_loop is not None:
            camera_loop.stop()

    def loop(self):
        """Sends camera images in an infinite loop."""
        bio = io.BytesIO()

        if args.use_usb:
            _, frame = camera.read()
            img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            img.save(bio, "JPEG")
        else:
            camera.capture(bio, "jpeg", use_video_port=True)

        try:
            self.write_message(base64.b64encode(bio.getvalue()))
        except tornado.websocket.WebSocketClosedError:
            self.camera_loop.stop()
def broadcast_players(g_id):
    """Publish the player count of every game whose key starts with ``g_id``.

    The first call (and any call with a different ``g_id``) also (re)schedules
    this function to run every 4 seconds through the module-level ``pcb2``
    callback.

    :param g_id: key prefix used for the redis lookup and the channel name.
    :returns: dict mapping each game key to its player count (as a string),
        plus ``"channel": "player_list"``.
    """
    global pcb2
    global cur_g_id

    # (Re)schedule when nothing is scheduled yet or the game id changed.
    if pcb2 is None or cur_g_id != g_id:
        if pcb2 is not None:
            pcb2.stop()
        cur_g_id = g_id
        pcb2 = PeriodicCallback(lambda: broadcast_players(g_id), 4000)
        pcb2.start()

    data_dict = {}
    for g_key in ww_redis_db.keys(g_id + "*"):
        game = ww_redis_db.hgetall(g_key)
        # convert from byte array to string dict
        game = {k.decode('utf8'): v.decode('utf8') for k, v in game.items()}
        # players in the current game, stored as "|"-separated uuids
        players = game['players'].split("|")
        data_dict[g_key.decode("utf-8")] = str(len(players))

    data_dict["channel"] = "player_list"
    publish_data("player_list:" + g_id, data_dict)
    return data_dict
class SubscriberConnection(ConnectionMixin, SockJSConnection):
    """SockJS connection that routes JSON messages to route handlers.

    Optionally sends heartbeats and runs every configured middleware class
    before a message is dispatched.
    """
    pub_sub = None
    user = None

    def __init__(self, session):
        super(SubscriberConnection, self).__init__(session)
        self.session_store = session_store(self)
        self.pub_sub = get_subscription_provider()

    def _close_invalid_origin(self):
        """Close the connection with code 4000 (invalid origin)."""
        self.close(4000, message='Invalid origin')

    def on_open(self, request):
        """Validate the origin, then start heartbeats when enabled."""
        if not set_origin_connection(request, self):
            self._close_invalid_origin()
            return
        super(SubscriberConnection, self).on_open(request)
        if is_heartbeat_enabled():
            self.periodic_callback = PeriodicCallback(
                self.send_heartbeat, get_heartbeat_frequency())
            self.periodic_callback.start()

    def send_heartbeat(self):
        """Send a heartbeat message to the client."""
        self.send({'heartbeat': '1'})

    def on_heartbeat(self):
        """Refresh the session keys when the client answers a heartbeat."""
        self.session_store.refresh_all_keys()

    def on_close(self):
        """Stop the heartbeat (if any) and release the subscriptions."""
        if hasattr(self, 'periodic_callback'):
            self.periodic_callback.stop()
        self.pub_sub.close(self)

    def on_message(self, data):
        """Decode ``data``, run the middleware and dispatch to its route.

        Any exception aborts the connection and is re-raised.
        """
        if not test_origin(self):
            self._close_invalid_origin()
            # Do not process messages on a connection that is being rejected.
            return
        try:
            data = self.to_json(data)
            if data == {'heartbeat': '1'}:
                self.on_heartbeat()
                return
            middleware_classes = getattr(
                settings, 'SWAMP_DRAGON_MIDDLEWARE_CLASSES', None)
            for middleware_class in middleware_classes or ():
                # Resolve each entry itself; indexing [0] here loaded the
                # first middleware once per configured entry.
                middleware_path, middleware_name = \
                    middleware_class.rsplit('.', 1)
                middleware = getattr(
                    import_module(middleware_path), middleware_name)
                middleware().process_request(self, data)
            handler = route_handler.get_route_handler(data['route'])
            handler(self).handle(data)
        except Exception:
            self.abort_connection()
            raise  # bare raise keeps the original traceback

    def abort_connection(self):
        """Close the connection with code 3001."""
        self.close(code=3001, message='Connection aborted')

    def close(self, code=3000, message='Connection closed'):
        """Close the underlying session with ``code`` and ``message``."""
        self.session.close(code, message)
class ProgressWidget(Progress):
    """ ProgressBar that uses an IPython ProgressBar widget for the notebook

    The bar is refreshed by a ``PeriodicCallback`` every ``dt`` seconds until
    ``stop`` is called.

    See Also
    --------
    progress: User function
    Progress: Super class with most of the logic
    TextProgressBar: Text version suitable for the console
    """
    def __init__(self, keys, scheduler=None, minimum=0, dt=0.1, complete=False):
        # Accept either raw keys or objects exposing a ``key`` attribute.
        keys = {k.key if hasattr(k, 'key') else k for k in keys}
        self.setup_pre(keys, scheduler, minimum, dt, complete)

        def clear_errors(errors):
            # Report every key returned by setup() through task_erred().
            for k in errors:
                self.task_erred(None, k, None, True)

        # Run setup directly when already on the scheduler loop's thread;
        # otherwise hand it to that loop via sync().
        if self.scheduler.loop._thread_ident == threading.current_thread().ident:
            errors = self.setup(keys, complete)
        else:
            errors = sync(self.scheduler.loop, self.setup, keys, complete)

        # The interval is given in milliseconds.
        self.pc = PeriodicCallback(self._update, 1000 * self._dt)

        # Imported here so ipywidgets is only needed when a widget is built.
        from ipywidgets import FloatProgress
        self.bar = FloatProgress(min=0, max=1, description='0.0s')
        self.widget = self.bar

        # The bar must exist before errors are reported.
        # NOTE(review): presumably task_erred() ends up touching self.bar
        # via stop() -- confirm.
        clear_errors(errors)

        self.pc.start()

    def setup(self, keys, complete):
        """Delegate to ``Progress.setup`` and return its errors."""
        errors = Progress.setup(self, keys, complete)
        return errors

    def _ipython_display_(self, **kwargs):
        """Display hook: delegate to the wrapped widget."""
        return self.widget._ipython_display_(**kwargs)

    def _start(self):
        """Refresh the bar once."""
        return self._update()

    def stop(self, exception=None, key=None):
        """Stop refreshing and colour the bar according to the outcome."""
        # NOTE(review): ``key`` is not forwarded (``key=None`` is passed) --
        # confirm this is intentional.
        Progress.stop(self, exception, key=None)
        # self.pc may not exist yet if stop() runs before __init__ finished.
        with ignoring(AttributeError):
            self.pc.stop()
        self._update()
        if exception:
            self.bar.bar_style = 'danger'
            self.bar.value = 1.0
        elif not self.keys:
            self.bar.bar_style = 'success'

    def _update(self):
        """Set the bar to the fraction of finished keys and show elapsed time."""
        ntasks = len(self.all_keys)
        ndone = ntasks - len(self.keys)
        # With no tasks at all the bar is shown as complete.
        self.bar.value = ndone / ntasks if ntasks else 1.0
        self.bar.description = format_time(self.elapsed)
class BrowserSession(object):
    """Runs an IOLoop in a background thread for the lifetime of a browser.

    Usable as a context manager: entering starts the thread and returns the
    driver, leaving requests the loop to stop and raises ``IOLoopException``
    if the loop thread reported a problem (unknown message or timeout).
    """

    def __init__(self, driver_name, ioloop, timeout=5):
        self._driver = _DRIVERS[driver_name]()
        self._ioloop = ioloop
        self._out_queue = Queue()  # loop thread -> caller (error reports)
        self._in_queue = Queue()  # caller -> loop thread (commands)
        self._timeout = timeout

        keyword_arguments = {
            "in_queue": self._in_queue,
            "out_queue": self._out_queue
        }
        self._thread = threading.Thread(
            target=self._on_start, kwargs=keyword_arguments)
        self._thread.daemon = True

    def start(self):
        """Start the loop thread and return the driver."""
        self._driver.set_page_load_timeout(self._timeout)
        self._thread.start()
        return self._driver

    def stop(self):
        """Ask the loop thread to stop, wait for it and surface its errors.

        :raises IOLoopException: when the loop thread left a message behind.
        """
        self._in_queue.put("stop")
        self._thread.join()
        self._driver.delete_all_cookies()
        if not self._out_queue.empty():
            message = self._out_queue.get()
            raise IOLoopException("Error from IOLoop thread: %s" % (message))

    def _on_start(self, in_queue, out_queue):
        """Thread target: poll for commands every 100 ms and run the loop."""
        self._start_time = time.time()
        self._periodic_callback = PeriodicCallback(
            self._on_cycle, 100, io_loop=self._ioloop)
        self._periodic_callback.start()
        self._ioloop.start()

    def _on_cycle(self):
        """Handle a pending command and enforce the session timeout."""
        if not self._in_queue.empty():
            message = self._in_queue.get()
            self._on_stop()
            if message != "stop":
                self._out_queue.put("unknown message %s" % (message))
        if time.time() - self._start_time > self._timeout:
            # Go through _on_stop() so the periodic callback is cancelled as
            # well; stopping only the loop left it scheduled on the IOLoop.
            self._on_stop()
            self._out_queue.put("timeout")

    def _on_stop(self):
        """Cancel the polling callback and stop the loop."""
        self._periodic_callback.stop()
        self._ioloop.stop()

    def __enter__(self):
        return self.start()

    def __exit__(self, *args, **kwargs):
        self.stop()
class WSHandler(tornado.websocket.WebSocketHandler):
    """Streams gestures and hand coordinates from OpenNI / Kinect to a client."""

    def open(self):
        """Websocket connection has opened: run ``loop`` every 5 milliseconds."""
        self.callback = PeriodicCallback(self.loop, 5)
        self.callback.start()
        global hands
        self.lastHands = copy.deepcopy(hands)
        print("Connection opened with client.")
        print("Wave to start hand tracking.")

    def on_close(self):
        """Websocket connection closed: stop the loop."""
        self.callback.stop()
        print("Connection closed with client.")

    def loop(self):
        """Poll the tracker and forward gestures and hand changes."""
        global hands
        global tracking
        global q

        # Poll for updates from OpenNI / Kinect
        tracking.context.wait_any_update_all()

        # If gesture is in the queue, send message to client.
        while not q.empty():
            self.write_message(q.get())

        # Send hand coordinates to client only when they changed.
        if hands != self.lastHands:
            self.lastHands = copy.deepcopy(hands)
            self.send_xy()

    def send_xy(self):
        """Send the current hand data, tagged with ``"type": "hands"``."""
        global hands
        # Tag a copy: writing the key into the shared ``hands`` dict changed
        # the tracking state and made it differ from ``lastHands``, which
        # caused one redundant resend.
        message = dict(hands)
        message["type"] = "hands"
        self.write_message(message)

    def on_message(self, message):
        """Incoming messages are ignored."""
        pass

    def check_origin(self, origin):
        """Allow cross origin requests."""
        return True
class SendWebSocket(tornado.websocket.WebSocketHandler):
    """Counter socket that also launches the mjpg streamer when needed.

    ``on_message`` receives data, ``write_message`` sends data (index.html).
    """

    def open(self):
        """Start the one-second counter and, if required, the video streamer."""
        # Disabled serial motor setup:
        #   tty = "/dev/ttyACM0"
        #   if os.path.exists(tty):
        #       self.s = serial.Serial(tty,115200)
        #       #self.s.open()
        #       enableMotor(self.s, 20)
        #       setMotor(self.s, 20, 1, 0 ,0)
        # Disabled I2C setup:
        #   self.s = mraa.I2c(1)
        #   self.s.address(0x20)
        #   self.s.writeReg(0x03,1)
        #   self.s.writeReg(0x04,1)
        #   self.s.writeReg(0x06,0)
        self.i = 0
        ticker = PeriodicCallback(self._send_message, 1000)
        self.callback = ticker
        ticker.start()
        # Launch the streamer only when a camera exists and pidof returned
        # status 256 (presumably: no running mjpg_streamer -- verify).
        camera_present = os.path.exists("/dev/video0")
        if camera_present and os.system("pidof mjpg_streamer") == 256:
            # os.system(' /home/root/mjpg_streamer/mjpg_streamer -i "./input_uvc.so -d /dev/video0 -r 320x240 -f 15 -n" -o "./output_http.so -p 9000 -w ./www" > /dev/null 2>&1 &')
            os.system('/home/root/mjpg-streamer/0start.sh')
        print("WebSocket opened")

    def on_message(self, message):
        """Print the received message (motor control is disabled)."""
        # Disabled serial motor control:
        #   if message == "A key : 1":
        #       setMotorDuty(self.s, 20, 50)
        #   else:
        #       setMotorDuty(self.s, 20, 0)
        # Disabled I2C motor control:
        #   if message == "A key : 1":
        #       self.s.writeReg(0x06,50)
        #   else:
        #       self.s.writeReg(0x06,0)
        print(message)

    def _send_message(self):
        """Increment the counter and send it."""
        self.i = self.i + 1
        self.write_message(str(self.i))

    def on_close(self):
        """Stop the counter."""
        self.callback.stop()
        print("WebSocket closed")
class WebSocketTest(WebSocketHandler):
    """Small wood/heat game driven by a 5 second loop.

    Clients add wood with ``add_1`` / ``add_5`` messages; the full game state
    is sent back after every change and after every loop turn.
    """

    def open(self):
        """Initialise the game state, send it and start the loop."""
        print("Opened websocket.")
        self.game_data = {}
        # Initialised here so do_loop() never depends on which branch ran
        # first (it was previously only created by the "has wood" branch).
        self.turns_no_wood = 0
        self.loop_callback = PeriodicCallback(self.do_loop, 5000)
        self.game_data["wood"] = 10
        self.game_data["heat"] = 0
        self.game_data["progress"] = 0
        self.write_json(self.game_data)
        self.loop_callback.start()

    def on_message(self, message):
        """Apply an ``add_1`` / ``add_5`` request and send the new state."""
        print("Message: {}".format(message))
        message = json.loads(message)
        kind = message["type"]
        if kind == "add_1":
            self.game_data["wood"] += 1
            self.write_json(self.game_data)
        elif kind == "add_5":
            self.game_data["wood"] += 5
            self.write_json(self.game_data)

    def do_loop(self):
        """Advance the game by one turn and send the resulting state."""
        data = self.game_data
        if data["wood"] < 1:
            # Without wood the heat drops faster every turn, never below 0.
            self.turns_no_wood += 1
            data["heat"] -= 5 * self.turns_no_wood
            if data["heat"] < 1:
                data["heat"] = 0
        else:
            self.turns_no_wood = 0
            data["heat"] += 5 * data["wood"]
            # Floor division keeps the integer arithmetic this Python 2 code
            # was written with, on any interpreter.
            data["wood"] -= data["heat"] // 50
            if data["wood"] < 1:
                data["wood"] = 0
        if data["heat"] > 2000:
            data["progress"] += 1
        if data["heat"] > 2500:
            print("Heat failure.")
            data["fail"] = True
        if data["progress"] > 9:
            print("Success.")
            data["success"] = True
        self.write_json(data)

    def on_close(self):
        """Stop the game loop."""
        print("Closed websocket.")
        self.loop_callback.stop()

    def write_json(self, message):
        """Send ``message`` serialised as JSON."""
        self.write_message(json.dumps(message))
class PoolQueue(object):
    """Periodically dequeues signals from a signalqueue backend.

    Keyword arguments: ``active`` (start immediately, default True), ``halt``
    (exit once the queue is exhausted, default False), ``queue_name``
    (default ``"default"``) and ``interval`` (defaults to the queue's own
    ``queue_interval``, falling back to 1).
    """

    def __init__(self, *args, **kwargs):
        super(PoolQueue, self).__init__()

        from django.conf import settings as django_settings
        from signalqueue.worker import backends
        from signalqueue import SQ_RUNMODES as runmodes

        self.active = kwargs.get("active", True)
        self.halt = kwargs.get("halt", False)
        self.interval = 1
        self.queue_name = kwargs.get("queue_name", "default")

        self.runmode = runmodes["SQ_ASYNC_DAEMON"]
        self.queues = backends.ConnectionHandler(
            django_settings.SQ_QUEUES, self.runmode)
        self.signalqueue = self.queues[self.queue_name]
        self.signalqueue.runmode = self.runmode

        # use interval from the config if it exists
        interval = kwargs.get("interval", self.signalqueue.queue_interval)
        if interval is not None:
            self.interval = interval

        # Always define the attribute: with a non-positive interval no
        # callback is created, and start/stop/rerack must not fail on it.
        self.shark = None
        if self.interval > 0:
            target = self.cueball_scratch if self.halt else self.cueball
            self.shark = PeriodicCallback(target, self.interval * 10)

        if self.active and self.shark is not None:
            self.shark.start()

    def stop(self):
        """Deactivate the pool and stop dequeueing."""
        self.active = False
        if self.shark is not None:
            self.shark.stop()

    def rerack(self):
        """Reactivate the pool and resume dequeueing."""
        self.active = True
        if self.shark is not None:
            self.shark.start()

    def cueball(self):
        """Dequeue one signal, logging any exception it raises."""
        # logg.info("Dequeueing signal...")
        with self.signalqueue.log_exceptions():
            self.signalqueue.dequeue()

    def cueball_scratch(self):
        """Dequeue one signal and exit once the queue is empty.

        :raises KeyboardInterrupt: when the queue holds no more signals.
        """
        # logg.info("Dequeueing signal...")
        with self.signalqueue.log_exceptions():
            self.signalqueue.dequeue()

        if self.signalqueue.count() < 1:
            print("Queue exhausted, exiting...")
            raise KeyboardInterrupt
class SubscriberConnection(ConnectionMixin, SockJSConnection):
    """SockJS connection that routes JSON messages to route handlers.

    Optionally sends heartbeats at the configured frequency.
    """
    pub_sub = None

    def __init__(self, session):
        super(SubscriberConnection, self).__init__(session)
        self.session_store = session_store(self)
        self.pub_sub = get_subscription_provider()

    def _close_invalid_origin(self):
        """Close the connection with code 4000 (invalid origin)."""
        self.close(4000, message='Invalid origin')

    def on_open(self, request):
        """Validate the origin, then start heartbeats when enabled."""
        if not set_origin_connection(request, self):
            self._close_invalid_origin()
            return
        super(SubscriberConnection, self).on_open(request)
        if is_heartbeat_enabled():
            self.periodic_callback = PeriodicCallback(
                self.send_heartbeat, get_heartbeat_frequency())
            self.periodic_callback.start()

    def send_heartbeat(self):
        """Send a heartbeat message to the client."""
        self.send({'heartbeat': '1'})

    def on_heartbeat(self):
        """Refresh the session keys when the client answers a heartbeat."""
        self.session_store.refresh_all_keys()

    def on_close(self):
        """Stop the heartbeat (if any) and release the subscriptions."""
        if hasattr(self, 'periodic_callback'):
            self.periodic_callback.stop()
        self.pub_sub.close(self)

    def on_message(self, data):
        """Decode ``data`` and dispatch it to the handler for its route.

        Any exception aborts the connection and is re-raised.
        """
        if not test_origin(self):
            self._close_invalid_origin()
            # Do not process messages on a connection that is being rejected.
            return
        try:
            data = self.to_json(data)
            if data == {'heartbeat': '1'}:
                self.on_heartbeat()
                return
            handler = route_handler.get_route_handler(data['route'])
            handler(self).handle(data)
        except Exception:
            self.abort_connection()
            raise  # bare raise keeps the original traceback

    def abort_connection(self):
        """Close the connection with code 3001."""
        self.close(code=3001, message='Connection aborted')

    def close(self, code=3000, message='Connection closed'):
        """Close the underlying session with ``code`` and ``message``."""
        self.session.close(code, message)
class ProcessStatsMonitor(object):
    """Emits statistics about the current process at a fixed interval."""
    signal_updated = object()

    def __init__(self, interval=1.0):
        self.signals = SignalManager(self)
        self.process = psutil.Process(os.getpid())
        self.interval = interval
        # ``interval`` is in seconds, PeriodicCallback expects milliseconds.
        self._task = PeriodicCallback(self._emit, self.interval*1000)
        self._recent = {}

    def start(self):
        """Begin emitting stats."""
        # yappi.start()
        self._task.start()

    def stop(self):
        """Stop emitting stats."""
        self._task.stop()
        # stats = yappi.get_func_stats()
        # stats.sort('tsub', 'desc')
        # with open("func-stats.txt", 'wt') as f:
        #     stats.print_all(f, columns={
        #         0: ("name", 80),
        #         1: ("ncall", 10),
        #         2: ("tsub", 8),
        #         3: ("ttot", 8),
        #         4: ("tavg",8)
        #     })
        #
        # pstats = yappi.convert2pstats(stats)
        # pstats.dump_stats("func-stats.prof")

    def get_recent(self):
        """Return the most recently emitted stats dict."""
        return self._recent

    def _emit(self):
        """Collect a stats snapshot, remember it and send ``signal_updated``."""
        proc = self.process
        cpu_times = proc.cpu_times()
        ram_usage = proc.memory_info()
        stats = {}
        stats['ram_percent'] = proc.memory_percent()
        stats['ram_rss'] = ram_usage.rss
        stats['ram_vms'] = ram_usage.vms
        stats['cpu_percent'] = proc.cpu_percent()
        stats['cpu_time_user'] = cpu_times.user
        stats['cpu_time_system'] = cpu_times.system
        stats['num_fds'] = proc.num_fds()
        stats['context_switches'] = proc.num_ctx_switches()
        stats['num_threads'] = proc.num_threads()
        stats['server_time'] = int(time.time()*1000)
        self._recent = stats
        self.signals.send_catch_log(self.signal_updated, stats=stats)
class WSHandler(tornado.websocket.WebSocketHandler):
    """Pushes ``ws_push_data`` as JSON to the client every two seconds."""

    def open(self):
        """Start the periodic push."""
        pusher = PeriodicCallback(self.send_hello, 2000)
        self.callback = pusher
        pusher.start()

    def send_hello(self):
        """Serialise the module-level ``ws_push_data`` and send it."""
        payload = json.dumps(ws_push_data)
        self.write_message(payload)

    def on_message(self, message):
        """Incoming messages are ignored."""
        pass

    def on_close(self):
        """Stop the periodic push."""
        self.callback.stop()
class WSHandler(tornado.websocket.WebSocketHandler):
    """Sends the text ``Websocket`` to the client every 10 milliseconds."""

    def open(self):
        """Start the periodic sender."""
        sender = PeriodicCallback(self.send_hello, 10)
        self.callback = sender
        sender.start()

    def send_hello(self):
        """Send the fixed greeting."""
        greeting = "Websocket"
        self.write_message(greeting)

    def on_message(self, message):
        """Print every reply received from the client."""
        print("Recieved Websocket Reply {}".format(message))

    def on_close(self):
        """Stop the periodic sender."""
        self.callback.stop()
class AsyncPopen(object):
    '''Asynchronous version of :class:`subprocess.Popen`.

    The child's stdout and stderr are attached to a pseudo-terminal whose
    master side is watched by the IOLoop. Output is delivered through the
    ``on_output`` event and the return code through ``on_end``.
    '''
    def __init__(self, *args, **kwargs):
        # Arguments are stored and passed to subprocess.Popen by run().
        self.args = args
        self.kwargs = kwargs
        self.on_output = Event()
        self.on_end = Event()

    def run(self):
        '''Start the subprocess and begin watching its output and exit.'''
        self.ioloop = IOLoop.instance()
        (master_fd, slave_fd) = pty.openpty()

        # make stdout, stderr non-blocking
        fcntl.fcntl(master_fd, fcntl.F_SETFL, fcntl.fcntl(master_fd, fcntl.F_GETFL) | os.O_NONBLOCK)

        self.master_fd = master_fd
        self.master = os.fdopen(master_fd)

        # listen to stdout, stderr
        self.ioloop.add_handler(master_fd, self._handle_subprocess_stdout, self.ioloop.READ)

        # The child writes both streams to the slave side of the pty.
        slave = os.fdopen(slave_fd)
        self.kwargs["stdout"] = slave
        self.kwargs["stderr"] = slave
        self.kwargs["close_fds"] = True
        self.pipe = subprocess.Popen(*self.args, **self.kwargs)

        # NOTE(review): this is None unless the caller passed stdin=PIPE.
        self.stdin = self.pipe.stdin

        # check for process exit
        self.wait_callback = PeriodicCallback(self._wait_for_end, 250)
        self.wait_callback.start()

    def _handle_subprocess_stdout(self, fd, events):
        '''IOLoop handler: forward readable output, then check for exit.'''
        if not self.master.closed and (events & IOLoop._EPOLLIN) != 0:
            data = self.master.read()
            self.on_output(data)

        self._wait_for_end(events)

    def _wait_for_end(self, events=0):
        '''Clean up and fire ``on_end`` once the process exited or hung up.'''
        self.pipe.poll()
        if self.pipe.returncode is not None or \
                (events & tornado.ioloop.IOLoop._EPOLLHUP) > 0:
            self.wait_callback.stop()
            self.master.close()
            self.ioloop.remove_handler(self.master_fd)
            # NOTE(review): returncode may still be None here when only the
            # hang-up event triggered this branch.
            self.on_end(self.pipe.returncode)
class SendWebSocket(tornado.websocket.WebSocketHandler):
    """Sends a shared, steadily incrementing number to every client.

    The number only advances while at least one client is connected.
    """
    num = 0  # the number that is incremented
    conn = 0  # number of currently connected hosts
    # Shared by all connections: it is started by the first client and must
    # be stoppable by whichever client disconnects last.
    _increment_callback = None

    @classmethod
    def update_num(cls):
        """Advance the shared number by one."""
        cls.num += 1

    @classmethod
    def inclement_conn(cls):
        """Count one more connected host."""
        cls.conn += 1

    @classmethod
    def decrement_conn(cls):
        """Count one fewer connected host."""
        cls.conn -= 1

    @classmethod
    def get_conn(cls):
        """Return the number of connected hosts."""
        return cls.conn

    def check_origin(self, origin):
        """Accept connections from any origin."""
        return True

    def open(self):
        """Start sending to this client; start counting for the first one."""
        cls = type(self)
        # Callback that sends the message to this client.
        self.callback = PeriodicCallback(self._send_message, 400)
        self.callback.start()
        self.inclement_conn()
        # Start incrementing when the number of connected hosts becomes 1.
        if self.get_conn() == 1:
            cls._increment_callback = PeriodicCallback(cls.update_num, 400)
            cls._increment_callback.start()
        # Kept for callers that read the per-instance attribute.
        self.increment_callback = cls._increment_callback
        print('WebSocket opendやで')

    def on_message(self, message):
        """Print any received message."""
        print(message)

    def _send_message(self):
        """Send the current number and the connected host count."""
        self.write_message(str(self.num) + ' 接続中のホスト数=' + str(self.conn))

    def on_close(self):
        """Stop sending to this client; stop counting after the last one."""
        cls = type(self)
        self.callback.stop()
        self.decrement_conn()
        # Stop incrementing once no hosts are connected. The shared callback
        # is stopped, which may have been started by another connection.
        if self.get_conn() == 0 and cls._increment_callback is not None:
            cls._increment_callback.stop()
            cls._increment_callback = None
        print('WebSocket closed')
class TornadoTelemetryManager(TelemetryManager):  # pylint: disable=W0612
    """Telemetry manager whose clean-up runs on a Tornado periodic timer."""

    def __init__(self, ioloop):
        TelemetryManager.__init__(self)
        self.ioloop = ioloop
        wrapped_clean_up = stack_context.wrap(self._start_clean_up_timer)
        period = self.CLEAN_UP_INTERVAL * self.CLEAN_UP_INTERVAL_MULTIPLIER
        self._timer = PeriodicCallback(wrapped_clean_up, period, self.ioloop)
        self._timer.start()

    @tornado.gen.coroutine
    def _start_clean_up_timer(self):
        """Timer callback: clean up the telemetry data."""
        self.clean_up_telemetry_data()

    def _stop_clean_up_timer(self):
        """Stop the periodic clean-up."""
        self._timer.stop()
class SpitRandomStuff(object):
    """Writes a random MD5 hex digest line to a stream every 500 ms."""

    def __init__(self, stream, address):
        log.info('Received connection from %s', address)
        self.address = address
        self.stream = stream
        self.stream.set_close_callback(self._on_close)
        self.writer = PeriodicCallback(self._random_stuff, 500)
        self.writer.start()

    def _on_close(self):
        """Stream close callback: stop writing."""
        log.info('Closed connection from %s', self.address)
        self.writer.stop()

    def _random_stuff(self):
        """Write the MD5 hex digest of 60 random bytes, followed by a newline."""
        noise = os.urandom(60)
        digest = md5(noise).hexdigest()
        self.stream.write(digest + "\n")
class ServerNewHandler(WebSocketBaseHandler):
    """WebSocket handler that adds (deploys) a new server.

    Progress and result are reported over the socket as plain text messages,
    ending with ``'success'`` or ``'failure'``.

    NOTE(review): the nesting in ``handle_msg``'s failure branch was
    reconstructed from flattened source -- confirm against the original.
    """

    def on_message(self, message):
        """Validate the request parameters and spawn the deployment."""
        self.params.update(json_loads(message))

        # Parameter validation
        try:
            args = ['cluster_id', 'name', 'public_ip', 'username', 'passwd']

            self.guarantee(*args)

            # Strip whitespace from every field except cluster_id.
            for i in args[1:]:
                self.params[i] = self.params[i].strip()

            validate_ip(self.params['public_ip'])

            self.params.update(self.get_lord())
            self.params.update({'owner': self.current_user['id']})
        except Exception as e:
            # Any validation error is reported to the client and ends the
            # connection.
            self.write_message(str(e))
            self.close()
            return

        IOLoop.current().spawn_callback(
            callback=self.handle_msg)  # on_message cannot be asynchronous; spawn_callback is needed to run the work asynchronously

    @coroutine
    def handle_msg(self):
        """Deploy the monitor on the host unless it is already (being) deployed."""
        is_deploying = self.redis.hget(DEPLOYING, self.params['public_ip'])
        is_deployed = self.redis.hget(DEPLOYED, self.params['public_ip'])

        # Notification that adding the host failed; later the failure reasons
        # should be abstracted into categories before being shown to the user
        message = {
            'owner': self.params.get('owner'),
            'ip': self.params['public_ip'],
            'tip': '{}'.format(
                self.params.get('lord') if self.params.get('form') == FORM_COMPANY else 0)
        }

        if is_deploying:
            reason = '%s 正在部署' % self.params['public_ip']
            self.write_message(reason)
            self.write_message('failure')
            message['reason'] = reason
            yield self.message_service.notify_server_add_failed(message)
            return

        if is_deployed:
            reason = '%s 之前已部署' % self.params['public_ip']
            self.write_message(reason)
            self.write_message('failure')
            message['reason'] = reason
            yield self.message_service.notify_server_add_failed(message)
            return

        # Encrypt before saving to redis
        passwd = self.params['passwd']
        self.params['passwd'] = Aes.encrypt(passwd)

        self.redis.hset(DEPLOYING, self.params['public_ip'], json.dumps(self.params))

        self.period = PeriodicCallback(self.check, 3000)  # set up the periodic function, 3 seconds
        self.period.start()

        # The plain-text password is put back for the SSH call below.
        self.params.update({
            'passwd': passwd,
            'cmd': MONITOR_CMD,
            'rt': True,
            'out_func': self.write_message
        })
        _, err = yield self.server_service.remote_ssh(self.params)

        err = [e for e in err if not re.search(r'symlink|resolve host', e)]  # ignore certain errors

        # Deployment failed
        if err:
            if err[0] == 'Authentication failed.':
                reason = '认证失败'
                self.write_message(reason)
                message['reason'] = reason
            yield self.message_service.notify_server_add_failed(message)
            self.write_message('failure')
            self.period.stop()
            self.close()

            self.redis.hdel(DEPLOYING, self.params['public_ip'])

    def check(self):
        ''' Check whether the host has reported its information. '''
        result = self.redis.hget(DEPLOYED, self.params['public_ip'])
        if result:
            self.write_message('success')
            self.period.stop()
            self.close()

    def on_close(self):
        """Stop the periodic check if it was ever created."""
        if hasattr(self, 'period'):
            self.period.stop()
class MetadataStorage:
    """Store of per-file metadata kept in a database namespace.

    Metadata is produced by running METADATA_SCRIPT through the
    ``shell_command`` plugin.  Requests are queued in ``pending_requests`` and
    drained by a single background coroutine; ``busy`` is True while that
    coroutine is running.
    """

    def __init__(self, server, gc_path, database):
        """Wrap the metadata namespace and prepare (but do not start) pruning.

        Args:
            server: object providing ``lookup_plugin``, ``error`` and
                ``send_event``.
            gc_path: directory that file names are resolved against.
            database: object providing ``register_local_namespace`` and
                ``wrap_namespace``.
        """
        self.server = server
        database.register_local_namespace(METADATA_NAMESPACE)
        self.mddb = database.wrap_namespace(METADATA_NAMESPACE,
                                            parse_keys=False)
        self.pending_requests = {}
        self.events = {}
        self.busy = False
        self.gc_path = gc_path
        self.prune_cb = PeriodicCallback(self.prune_metadata,
                                         METADATA_PRUNE_TIME)

    def update_gcode_path(self, path):
        """Switch to a new base path, clearing all stored metadata.

        Does nothing when ``path`` equals the current path.  Starts the
        periodic prune callback if it is not already running.
        """
        if path == self.gc_path:
            return
        self.mddb.clear()
        self.gc_path = path
        if not self.prune_cb.is_running():
            self.prune_cb.start()

    def close(self):
        """Stop the periodic prune callback."""
        self.prune_cb.stop()

    def get(self, key, default=None):
        """Return the stored metadata for ``key``, or ``default``."""
        return self.mddb.get(key, default)

    def __getitem__(self, key):
        return self.mddb[key]

    def prune_metadata(self):
        """Delete metadata entries whose file no longer exists on disk."""
        # Iterate over a snapshot of the keys since entries are deleted.
        for fname in list(self.mddb.keys()):
            fpath = os.path.join(self.gc_path, fname)
            if not os.path.exists(fpath):
                del self.mddb[fname]
                logging.info(f"Pruned file: {fname}")

    def _has_valid_data(self, fname, fsize, modified):
        """Return True if the stored size and mtime match the given values."""
        mdata = self.mddb.get(fname, {'size': "", 'modified': 0})
        return mdata['size'] == fsize and mdata['modified'] == modified

    def remove_file(self, fname):
        """Remove the entry for ``fname``; a missing entry is not an error."""
        try:
            del self.mddb[fname]
        except Exception:
            # Deliberately best-effort: there is nothing to remove.
            pass

    def parse_metadata(self, fname, fsize, modified, notify=False):
        """Queue a metadata extraction for ``fname``.

        Returns:
            An Event.  It is returned already set when a request for the file
            is pending or the stored data is current; otherwise it is set once
            the queued request has been processed.
        """
        evt = Event()
        if fname in self.pending_requests or \
                self._has_valid_data(fname, fsize, modified):
            # request already pending or not necessary
            evt.set()
            return evt
        self.pending_requests[fname] = (fsize, modified, notify, evt)
        if self.busy:
            # The running worker will pick this request up.
            return evt
        self.busy = True
        IOLoop.current().spawn_callback(self._process_metadata_update)
        return evt

    async def _process_metadata_update(self):
        """Drain ``pending_requests``, extracting metadata for each file.

        Each extraction is attempted up to three times.  When all attempts
        fail, a stub entry holding only size and mtime is stored so the file
        is not retried until it changes.

        The request's event is always set and ``busy`` is always cleared, even
        when an unexpected error occurs; otherwise a single failure would
        leave waiters blocked and stop all future processing.
        """
        try:
            while self.pending_requests:
                fname, (fsize, modified, notify, evt) = \
                    self.pending_requests.popitem()
                try:
                    if self._has_valid_data(fname, fsize, modified):
                        continue
                    retries = 3
                    while retries:
                        try:
                            await self._run_extract_metadata(fname, notify)
                        except Exception:
                            logging.exception(
                                "Error running extract_metadata.py")
                            retries -= 1
                        else:
                            break
                    else:
                        # All retries exhausted.
                        self.mddb[fname] = {'size': fsize,
                                            'modified': modified}
                        logging.info(
                            f"Unable to extract medatadata from file: {fname}")
                except Exception:
                    # Keep the worker alive for the remaining requests.
                    logging.exception(
                        f"Error processing metadata request: {fname}")
                finally:
                    evt.set()
        finally:
            self.busy = False

    async def _run_extract_metadata(self, filename, notify):
        """Run the extraction script for one file and store its result.

        Raises:
            The server's error type when the command returns no response or
            reports empty metadata; the JSON decoding error when the response
            is not valid JSON.
        """
        # Escape double quotes in the file name, since the name is wrapped
        # in double quotes on the command line below.
        filename = filename.replace("\"", "\\\"")
        cmd = " ".join([
            sys.executable, METADATA_SCRIPT, "-p", self.gc_path, "-f",
            f"\"{filename}\""
        ])
        shell_command = self.server.lookup_plugin('shell_command')
        scmd = shell_command.build_shell_command(cmd, log_stderr=True)
        result = await scmd.run_with_response(timeout=10.)
        if result is None:
            raise self.server.error("Metadata extraction error")
        try:
            decoded_resp = json.loads(result.strip())
        except Exception:
            logging.debug(f"Invalid metadata response:\n{result}")
            raise
        path = decoded_resp['file']
        metadata = decoded_resp['metadata']
        if not metadata:
            # This indicates an error, do not add metadata for this
            raise self.server.error("Unable to extract metadata")
        self.mddb[path] = dict(metadata)
        metadata['filename'] = path
        if notify:
            self.server.send_event("file_manager:metadata_update", metadata)
class State(BaseState):
    """Redis-backed state: channel presence and channel message history.

    Every coroutine returns a ``(result, error)`` pair through ``Return``.
    When ``self.fake`` is set no Redis call is made.
    """

    NAME = "Redis"

    OK_RESPONSE = "OK"

    def __init__(self, *args, **kwargs):
        super(State, self).__init__(*args, **kwargs)
        # All of these are populated by initialize().
        self.host = self.port = self.db = None
        self.client = None
        self.connection_check = None

    def initialize(self):
        """Read connection settings, build the client and connect."""
        redis_settings = self.config.get('settings', {})
        self.host = redis_settings.get("host", "localhost")
        self.port = redis_settings.get("port", 6379)
        self.db = redis_settings.get("db", 0)

        client = toredis.Client(io_loop=self.io_loop)
        client.state = self
        self.client = client

        # Checks the connection once per second (see check_connection).
        self.connection_check = PeriodicCallback(self.check_connection, 1000)
        self.connect()
        logger.info("Redis State initialized")

    def on_select(self, res):
        """Log an error when the SELECT response is not the OK response."""
        if res == self.OK_RESPONSE:
            return
        logger.error("state select database: {0}".format(res))

    def connect(self):
        """
        Connect to Redis and (re)start the periodic connection check.
        Nothing at all is done when State is faked.
        """
        if self.fake:
            return

        try:
            self.client.connect(host=self.host, port=self.port)
        except Exception as exc:
            logger.error("error connecting to Redis server: %s" % (str(exc)))
        else:
            # Only a non-zero integer database index is selected.
            if isinstance(self.db, int) and self.db:
                self.client.select(self.db, callback=self.on_select)

        self.connection_check.stop()
        self.connection_check.start()

    def check_connection(self):
        """Reconnect if the client reports that it is not connected."""
        if self.client.is_connected():
            return
        logger.info('reconnecting to Redis')
        self.connect()

    @staticmethod
    def get_presence_set_key(project_id, namespace, channel):
        """Return the key of the sorted set holding presence expirations."""
        return "centrifuge:presence:set:%s:%s:%s" % (project_id, namespace,
                                                      channel)

    @coroutine
    def add_presence(self, project_id, namespace, channel, uid, user_info,
                     presence_timeout=None):
        """
        Store a user's presence together with its expiration timestamp.
        Meant to be called when a user subscribes to a channel.
        """
        if self.fake:
            raise Return((True, None))

        timeout = presence_timeout or self.presence_timeout
        expire_at = int(time.time()) + timeout
        hash_key = self.get_presence_hash_key(project_id, namespace, channel)
        set_key = self.get_presence_set_key(project_id, namespace, channel)

        try:
            yield Task(self.client.zadd, set_key, {uid: expire_at})
            yield Task(self.client.hset, hash_key, uid,
                       json_encode(user_info))
        except StreamClosedError as exc:
            raise Return((None, exc))
        raise Return((True, None))

    @coroutine
    def remove_presence(self, project_id, namespace, channel, uid):
        """
        Delete a user's presence from Redis.
        Meant to be called on any kind of disconnect.
        """
        if self.fake:
            raise Return((True, None))

        hash_key = self.get_presence_hash_key(project_id, namespace, channel)
        set_key = self.get_presence_set_key(project_id, namespace, channel)

        try:
            yield Task(self.client.hdel, hash_key, uid)
            yield Task(self.client.zrem, set_key, uid)
        except StreamClosedError as exc:
            raise Return((None, exc))
        raise Return((True, None))

    @coroutine
    def get_presence(self, project_id, namespace, channel):
        """
        Return the presence information of a channel.
        Entries whose expiration time has passed are removed first.
        """
        if self.fake:
            raise Return((None, None))

        now = int(time.time())
        hash_key = self.get_presence_hash_key(project_id, namespace, channel)
        set_key = self.get_presence_set_key(project_id, namespace, channel)

        try:
            stale = yield Task(self.client.zrangebyscore, set_key, 0, now)
            if stale:
                yield Task(self.client.zremrangebyscore, set_key, 0, now)
                yield Task(self.client.hdel, hash_key,
                           [member.decode() for member in stale])
            data = yield Task(self.client.hgetall, hash_key)
        except StreamClosedError:
            raise Return((None, 'presence unavailable'))
        raise Return((dict_from_list(data), None))

    @coroutine
    def add_history_message(self, project_id, namespace, channel, message,
                            history_size=None):
        """
        Push a message onto the channel history and trim it to size.
        Meant to be called when a new message has been published.
        """
        if self.fake:
            raise Return((True, None))

        history_size = history_size or self.history_size
        list_key = self.get_history_list_key(project_id, namespace, channel)

        try:
            yield Task(self.client.lpush, list_key, json_encode(message))
            yield Task(self.client.ltrim, list_key, 0, history_size - 1)
        except StreamClosedError as exc:
            raise Return((None, exc))
        raise Return((True, None))

    @coroutine
    def get_history(self, project_id, namespace, channel):
        """
        Return the list of stored messages of a channel.
        """
        if self.fake:
            raise Return((None, None))

        history_list_key = self.get_history_list_key(project_id, namespace,
                                                     channel)

        try:
            raw_items = yield Task(self.client.lrange, history_list_key,
                                   0, -1)
        except StreamClosedError:
            raise Return((None, self.application.INTERNAL_SERVER_ERROR))
        raise Return(([json_decode(item.decode()) for item in raw_items],
                      None))
class Authorization:
    """Request authorization: JWT users, API key, oneshot tokens, trusted
    clients and CORS checks.

    User records live in the ``authorized_users`` database namespace.
    Trusted connections are cached per IP in ``trusted_users`` and pruned
    periodically.
    """

    def __init__(self, config):
        """Load the configuration and register the ``/access/*`` endpoints.

        Raises:
            config.error: when a CORS domain has a wildcard in the top level
                domain.
        """
        self.server = config.get_server()
        self.login_timeout = config.getint('login_timeout', 90)
        database = self.server.lookup_component('database')
        database.register_local_namespace('authorized_users', forbidden=True)
        self.users = database.wrap_namespace('authorized_users')
        api_user = self.users.get(API_USER, None)
        if api_user is None:
            # First run: create the API user with a fresh key.
            self.api_key = uuid.uuid4().hex
            self.users[API_USER] = {
                'username': API_USER,
                'api_key': self.api_key,
                'created_on': time.time()
            }
        else:
            self.api_key = api_user['api_key']
        self.trusted_users = {}
        self.oneshot_tokens = {}
        self.permitted_paths = set()

        # Get allowed cors domains
        self.cors_domains = []
        cors_cfg = config.get('cors_domains', "").strip()
        cds = [d.strip() for d in cors_cfg.split('\n') if d.strip()]
        for domain in cds:
            bad_match = re.search(r"^.+\.[^:]*\*", domain)
            if bad_match is not None:
                raise config.error(
                    f"Unsafe CORS Domain '{domain}'.  Wildcards are not"
                    " permitted in the top level domain.")
            # Convert the wildcard pattern to a regular expression.
            self.cors_domains.append(
                domain.replace(".", "\\.").replace("*", ".*"))

        # Get Trusted Clients
        self.trusted_ips = []
        self.trusted_ranges = []
        self.trusted_domains = []
        trusted_clients = config.get('trusted_clients', "")
        trusted_clients = [
            c.strip() for c in trusted_clients.split('\n') if c.strip()
        ]
        for val in trusted_clients:
            # Check IP address
            try:
                tc = ipaddress.ip_address(val)
            except ValueError:
                pass
            else:
                self.trusted_ips.append(tc)
                continue
            # Check ip network
            try:
                tc = ipaddress.ip_network(val)
            except ValueError:
                pass
            else:
                self.trusted_ranges.append(tc)
                continue
            # Check hostname
            self.trusted_domains.append(val.lower())

        t_clients = "\n".join([str(ip) for ip in self.trusted_ips] +
                              [str(rng) for rng in self.trusted_ranges] +
                              self.trusted_domains)
        c_domains = "\n".join(self.cors_domains)

        logging.info(f"Authorization Configuration Loaded\n"
                     f"Trusted Clients:\n{t_clients}\n"
                     f"CORS Domains:\n{c_domains}")

        self.prune_handler = PeriodicCallback(self._prune_conn_handler,
                                              PRUNE_CHECK_TIME)
        self.prune_handler.start()

        # Register Authorization Endpoints
        # Login and token refresh must be reachable without authorization.
        self.permitted_paths.add("/access/login")
        self.permitted_paths.add("/access/refresh_jwt")
        self.server.register_endpoint("/access/login", ['POST'],
                                      self._handle_login)
        self.server.register_endpoint("/access/logout", ['POST'],
                                      self._handle_logout)
        self.server.register_endpoint("/access/refresh_jwt", ['POST'],
                                      self._handle_refresh_jwt)
        self.server.register_endpoint("/access/user",
                                      ['GET', 'POST', 'DELETE'],
                                      self._handle_user_request)
        self.server.register_endpoint("/access/user/password", ['POST'],
                                      self._handle_password_reset)
        self.server.register_endpoint("/access/api_key", ['GET', 'POST'],
                                      self._handle_apikey_request,
                                      protocol=['http'])
        self.server.register_endpoint("/access/oneshot_token", ['GET'],
                                      self._handle_token_request,
                                      protocol=['http'])

    async def _handle_apikey_request(self, web_request):
        """Return the API key; a POST generates and stores a new one first."""
        action = web_request.get_action()
        if action.upper() == 'POST':
            self.api_key = uuid.uuid4().hex
            self.users[f'{API_USER}.api_key'] = self.api_key
        return self.api_key

    async def _handle_token_request(self, web_request):
        """Issue a oneshot token bound to the requester's IP and user."""
        ip = web_request.get_ip_address()
        user_info = web_request.get_current_user()
        return self.get_oneshot_token(ip, user_info)

    async def _handle_login(self, web_request):
        return self._login_jwt_user(web_request)

    async def _handle_logout(self, web_request):
        """Log the current user out by discarding the stored JWT secret."""
        user_info = web_request.get_current_user()
        if user_info is None:
            raise self.server.error("No user logged in")
        username = user_info['username']
        if username in RESERVED_USERS:
            raise self.server.error(
                f"Invalid log out request for user {username}")
        # Without the secret, previously issued tokens no longer verify.
        self.users.pop(f"{username}.jwt_secret", None)
        return {"username": username, "action": "user_logged_out"}

    async def _handle_refresh_jwt(self, web_request):
        """Exchange a valid refresh token for a new auth token."""
        refresh_token = web_request.get_str('refresh_token')
        user_info = self._decode_jwt(refresh_token, token_type="refresh")
        username = user_info['username']
        secret = bytes.fromhex(user_info['jwt_secret'])
        token = self._generate_jwt(username, secret)
        return {
            'username': username,
            'token': token,
            'action': 'user_jwt_refresh'
        }

    async def _handle_user_request(self, web_request):
        """GET: describe the current user; POST: create; DELETE: delete."""
        action = web_request.get_action()
        if action == "GET":
            user = web_request.get_current_user()
            if user is None:
                return {
                    'username': None,
                    'created_on': None,
                }
            else:
                return {
                    'username': user['username'],
                    'created_on': user.get('created_on')
                }
        elif action == "POST":
            # Create User
            return self._login_jwt_user(web_request, create=True)
        elif action == "DELETE":
            # Delete User
            return self._delete_jwt_user(web_request)

    async def _handle_password_reset(self, web_request):
        """Change the current user's password after checking the old one."""
        password = web_request.get_str('password')
        new_pass = web_request.get_str('new_password')
        user_info = web_request.get_current_user()
        if user_info is None:
            raise self.server.error("No Current User")
        username = user_info['username']
        if username in RESERVED_USERS:
            raise self.server.error(
                f"Invalid Reset Request for user {username}")
        salt = bytes.fromhex(user_info['salt'])
        hashed_pass = hashlib.pbkdf2_hmac('sha256', password.encode(), salt,
                                          HASH_ITER).hex()
        if hashed_pass != user_info['password']:
            raise self.server.error("Invalid Password")
        # The existing salt is reused for the new password hash.
        new_hashed_pass = hashlib.pbkdf2_hmac('sha256', new_pass.encode(),
                                              salt, HASH_ITER).hex()
        self.users[f'{username}.password'] = new_hashed_pass
        return {'username': username, 'action': "user_password_reset"}

    def _login_jwt_user(self, web_request, create=False):
        """Log a user in, optionally creating the account first.

        Returns:
            dict with ``username``, ``token``, ``refresh_token`` and
            ``action``.

        Raises:
            The server's error type for reserved names, an already existing
            or unregistered user, or a wrong password.
        """
        username = web_request.get_str('username')
        password = web_request.get_str('password')
        if username in RESERVED_USERS:
            raise self.server.error(f"Invalid Request for user {username}")
        if create:
            if username in self.users:
                raise self.server.error(f"User {username} already exists")
            salt = secrets.token_bytes(32)
            hashed_pass = hashlib.pbkdf2_hmac('sha256', password.encode(),
                                              salt, HASH_ITER).hex()
            user_info = {
                'username': username,
                'password': hashed_pass,
                'salt': salt.hex(),
                'created_on': time.time()
            }
            self.users[username] = user_info
            action = "user_created"
        else:
            if username not in self.users:
                raise self.server.error(f"Unregistered User: {username}")
            user_info = self.users[username]
            salt = bytes.fromhex(user_info['salt'])
            hashed_pass = hashlib.pbkdf2_hmac('sha256', password.encode(),
                                              salt, HASH_ITER).hex()
            action = "user_logged_in"
        if hashed_pass != user_info['password']:
            raise self.server.error("Invalid Password")
        jwt_secret = user_info.get('jwt_secret', None)
        if jwt_secret is None:
            # No active login: create and persist a new signing secret.
            jwt_secret = secrets.token_bytes(32)
            user_info['jwt_secret'] = jwt_secret.hex()
            self.users[username] = user_info
        else:
            jwt_secret = bytes.fromhex(jwt_secret)
        token = self._generate_jwt(username, jwt_secret)
        refresh_token = self._generate_jwt(
            username, jwt_secret, token_type="refresh",
            exp_time=datetime.timedelta(days=self.login_timeout))
        return {
            'username': username,
            'token': token,
            'refresh_token': refresh_token,
            'action': action
        }

    def _delete_jwt_user(self, web_request):
        """Delete the current user after verifying the password."""
        password = web_request.get_str('password')
        user_info = web_request.get_current_user()
        if user_info is None:
            raise self.server.error("No Current User")
        username = user_info['username']
        if username in RESERVED_USERS:
            raise self.server.error(f"Invalid request for user {username}")
        salt = bytes.fromhex(user_info['salt'])
        hashed_pass = hashlib.pbkdf2_hmac('sha256', password.encode(), salt,
                                          HASH_ITER).hex()
        if hashed_pass != user_info['password']:
            raise self.server.error("Invalid Password")
        del self.users[username]
        return {"username": username, "action": "user_deleted"}

    def _generate_jwt(self, username, secret, token_type="auth",
                      exp_time=JWT_EXP_TIME):
        """Build a JWT signed with HMAC-SHA256 using ``secret``.

        Args:
            username: value of the ``username`` claim.
            secret: signing key as bytes.
            token_type: value of the ``token_type`` claim.
            exp_time: timedelta added to the current time for ``exp``.

        Returns:
            The encoded token as a str.
        """
        curtime = time.time()
        payload = {
            'iss': "Moonraker",
            'iat': curtime,
            'exp': curtime + exp_time.total_seconds(),
            'username': username,
            'token_type': token_type
        }
        enc_header = base64url_encode(json.dumps(JWT_HEADER).encode())
        enc_payload = base64url_encode(json.dumps(payload).encode())
        message = enc_header + b"." + enc_payload
        signature = base64url_encode(hmac.digest(secret, message, "sha256"))
        message += b"." + signature
        return message.decode()

    def _decode_jwt(self, jwt, token_type="auth"):
        """Validate a JWT and return the user record it belongs to.

        Checks, in order: part count, header, token type, expiry, that the
        user exists and is logged in, and finally the signature.

        Raises:
            The server's error type when any check fails.
        """
        parts = jwt.encode().split(b".")
        if len(parts) != 3:
            raise self.server.error(f"Invalid JWT length of {len(parts)}")
        header = json.loads(base64url_decode(parts[0]))
        payload = json.loads(base64url_decode(parts[1]))
        if header != JWT_HEADER:
            raise self.server.error("Invalid JWT header")
        recd_type = payload.get('token_type', "")
        if token_type != recd_type:
            raise self.server.error(
                f"JWT Token type mismatch: Expected {token_type}, "
                f"Recd: {recd_type}", 401)
        if time.time() > payload['exp']:
            raise self.server.error("JWT expired", 401)
        username = payload.get('username')
        user_info = self.users.get(username, None)
        if user_info is None:
            raise self.server.error(
                f"Invalid JWT, no registered user {username}", 401)
        jwt_secret = user_info.get('jwt_secret', None)
        if jwt_secret is None:
            raise self.server.error(
                f"Invalid JWT, user {username} not logged in", 401)
        secret = bytes.fromhex(jwt_secret)
        # Decode and verify signature
        signature = base64url_decode(parts[2])
        calc_sig = hmac.digest(secret, parts[0] + b"." + parts[1], "sha256")
        # Constant-time comparison: the signature is attacker supplied, so
        # a plain != would leak how many leading bytes matched.
        if not hmac.compare_digest(signature, calc_sig):
            raise self.server.error("Invalid JWT signature")
        return user_info

    def _prune_conn_handler(self):
        """Drop trusted connections whose expiration time has passed."""
        cur_time = time.time()
        for ip, user_info in list(self.trusted_users.items()):
            exp_time = user_info['expires_at']
            if cur_time >= exp_time:
                self.trusted_users.pop(ip, None)
                logging.info(f"Trusted Connection Expired, IP: {ip}")

    def _oneshot_token_expire_handler(self, token):
        self.oneshot_tokens.pop(token, None)

    def get_oneshot_token(self, ip_addr, user):
        """Create a random token valid for ONESHOT_TIMEOUT seconds.

        The token is stored with the requesting IP, the user and the handle
        of the timer that expires it.
        """
        token = base64.b32encode(os.urandom(20)).decode()
        ioloop = IOLoop.current()
        hdl = ioloop.call_later(ONESHOT_TIMEOUT,
                                self._oneshot_token_expire_handler, token)
        self.oneshot_tokens[token] = (ip_addr, user, hdl)
        return token

    def _check_json_web_token(self, request):
        """Return the user for a ``Bearer`` token, or None if none was sent.

        Raises:
            HTTPError(401): when a bearer token is present but invalid.
        """
        auth_token = request.headers.get("Authorization")
        if auth_token is None:
            auth_token = request.headers.get("X-Access-Token")
        if auth_token and auth_token.startswith("Bearer "):
            auth_token = auth_token[7:]
            try:
                return self._decode_jwt(auth_token)
            except Exception as e:
                raise HTTPError(401, str(e))
        return None

    def _check_authorized_ip(self, ip):
        """Return True if ``ip`` is a trusted address, range or domain."""
        if ip in self.trusted_ips:
            return True
        for rng in self.trusted_ranges:
            if ip in rng:
                return True
        # Fall back to a name lookup and compare against trusted domains.
        fqdn = socket.getfqdn(str(ip)).lower()
        if fqdn in self.trusted_domains:
            return True
        return False

    def _check_trusted_connection(self, ip):
        """Return the trusted-user record for ``ip``, or None.

        A known trusted IP has its expiration extended; a newly authorized
        IP gets a fresh record.
        """
        if ip is not None:
            curtime = time.time()
            exp_time = curtime + TRUSTED_CONNECTION_TIMEOUT
            if ip in self.trusted_users:
                self.trusted_users[ip]['expires_at'] = exp_time
                return self.trusted_users[ip]
            elif self._check_authorized_ip(ip):
                logging.info(f"Trusted Connection Detected, IP: {ip}")
                self.trusted_users[ip] = {
                    'username': TRUSTED_USER,
                    'password': None,
                    'created_on': curtime,
                    'expires_at': exp_time
                }
                return self.trusted_users[ip]
        return None

    def _check_oneshot_token(self, token, cur_ip):
        """Consume a oneshot token and return its user, or None.

        The token is removed on first use, even when the IP does not match.
        """
        if token in self.oneshot_tokens:
            ip_addr, user, hdl = self.oneshot_tokens.pop(token)
            IOLoop.current().remove_timeout(hdl)
            if cur_ip != ip_addr:
                logging.info(f"Oneshot Token IP Mismatch: expected{ip_addr}"
                             f", Recd: {cur_ip}")
                return None
            return user
        else:
            return None

    def check_authorized(self, request):
        """Authorize a request and return the matching user record.

        Tries, in order: permitted path (returns None), JWT, oneshot token,
        API key header, trusted connection.

        Raises:
            HTTPError(401): when no method authorizes the request.
        """
        if request.path in self.permitted_paths:
            return None

        # Check JSON Web Token
        jwt_user = self._check_json_web_token(request)
        if jwt_user is not None:
            return jwt_user

        try:
            ip = ipaddress.ip_address(request.remote_ip)
        except ValueError:
            logging.exception(
                f"Unable to Create IP Address {request.remote_ip}")
            ip = None

        # Check oneshot access token
        ost = request.arguments.get('token', None)
        if ost is not None:
            ost_user = self._check_oneshot_token(ost[-1].decode(), ip)
            if ost_user is not None:
                return ost_user

        # Check API Key Header
        key = request.headers.get("X-Api-Key")
        # Compare as bytes in constant time; the header value is untrusted
        # and may contain non-ASCII characters.
        if key and hmac.compare_digest(key.encode(), self.api_key.encode()):
            return self.users[API_USER]

        # Check if IP is trusted
        trusted_user = self._check_trusted_connection(ip)
        if trusted_user is not None:
            return trusted_user

        raise HTTPError(401, "Unauthorized")

    def check_cors(self, origin, request=None):
        """Return True if ``origin`` is allowed, setting CORS headers if so.

        An origin is allowed when a configured pattern matches it completely
        or when its host is an IP address that is trusted.
        """
        if origin is None or not self.cors_domains:
            return False
        for regex in self.cors_domains:
            match = re.match(regex, origin)
            if match is not None:
                if match.group() == origin:
                    logging.debug(f"CORS Pattern Matched, origin: {origin} "
                                  f" | pattern: {regex}")
                    self._set_cors_headers(origin, request)
                    return True
                else:
                    logging.debug(f"Partial Cors Match: {match.group()}")
        else:
            # Check to see if the origin contains an IP that matches a
            # current trusted connection
            match = re.search(r"^https?://([^/:]+)", origin)
            if match is not None:
                ip = match.group(1)
                try:
                    ipaddr = ipaddress.ip_address(ip)
                except ValueError:
                    pass
                else:
                    if self._check_authorized_ip(ipaddr):
                        logging.debug(
                            f"Cors request matched trusted IP: {ip}")
                        self._set_cors_headers(origin, request)
                        return True
            logging.debug(f"No CORS match for origin: {origin}\n"
                          f"Patterns: {self.cors_domains}")
        return False

    def _set_cors_headers(self, origin, request):
        """Set the CORS response headers; a None ``request`` is a no-op."""
        if request is None:
            return
        request.set_header("Access-Control-Allow-Origin", origin)
        request.set_header("Access-Control-Allow-Methods",
                           "GET, POST, PUT, DELETE, OPTIONS")
        # NOTE(review): "X-CRSF-Token" looks like a misspelling of
        # "X-CSRF-Token"; left unchanged pending confirmation.
        request.set_header(
            "Access-Control-Allow-Headers",
            "Origin, Accept, Content-Type, X-Requested-With, "
            "X-CRSF-Token, Authorization, X-Access-Token, "
            "X-Api-Key")

    def close(self):
        """Stop the periodic pruning of trusted connections."""
        self.prune_handler.stop()
def test__lifecycle_hooks():
    # Verifies the order in which the handler's lifecycle hooks are recorded
    # across server load, one client session, and server shutdown.
    application = Application()
    handler = HookTestHandler()
    application.add(handler)
    with ManagedServerLoop(application, check_unused_sessions_milliseconds=30) as server:
        # wait for server callbacks to run before we mix in the
        # session, this keeps the test deterministic
        def check_done():
            # Four hooks recorded means the server-level hooks have all run.
            if len(handler.hooks) == 4:
                server.io_loop.stop()
        server_load_checker = PeriodicCallback(check_done, 1,
                                               io_loop=server.io_loop)
        server_load_checker.start()
        # Blocks until check_done stops the loop.
        server.io_loop.start()
        server_load_checker.stop()

        # now we create a session
        client_session = pull_session(session_id='test__lifecycle_hooks',
                                      url=url(server),
                                      io_loop=server.io_loop)
        client_doc = client_session.document
        assert len(client_doc.roots) == 1

        server_session = server.get_session('/', client_session.id)
        server_doc = server_session.document
        assert len(server_doc.roots) == 1

        client_session.close()
        # expire the session quickly rather than after the
        # usual timeout
        server_session.request_expiration()

        def on_done():
            server.io_loop.stop()

        # Run the loop for 0.1s so the session-level hooks can fire.
        server.io_loop.call_later(0.1, on_done)

        server.io_loop.start()

    # Checked after the with-block has exited, so "server_unloaded" is
    # expected as the last entry.
    assert handler.hooks == [
        "server_loaded", "next_tick_server", "timeout_server",
        "periodic_server", "session_created", "next_tick_session", "modify",
        "timeout_session", "periodic_session", "session_destroyed",
        "server_unloaded"
    ]

    client_hook_list = client_doc.roots[0]
    server_hook_list = server_doc.roots[0]
    assert handler.load_count == 1
    assert handler.unload_count == 1
    assert handler.session_creation_async_value == 6
    assert client_doc.title == "Modified"
    assert server_doc.title == "Modified"
    # the client session doesn't see the event that adds "session_destroyed" since
    # we shut down at that point.
    assert client_hook_list.hooks == ["session_created", "modify"]
    assert server_hook_list.hooks == [
        "session_created", "modify", "session_destroyed"
    ]
class PeriodicCallback(param.Parameterized):
    """
    Periodic encapsulates a periodic callback which will run both
    in tornado based notebook environments and on bokeh server. By
    default the callback will run until the stop method is called,
    but count and timeout values can be set to limit the number of
    executions or the maximum length of time for which the callback
    will run. The callback may also be started and stopped by setting
    the running parameter to True or False respectively.
    """

    callback = param.Callable(doc="""
        The callback to execute periodically.""")

    count = param.Integer(default=None, doc="""
        Number of times the callback will be executed, by default
        this is unlimited.""")

    period = param.Integer(default=500, doc="""
        Period in milliseconds at which the callback is executed.""")

    timeout = param.Integer(default=None, doc="""
        Timeout in milliseconds from the start time at which the callback
        expires.""")

    running = param.Boolean(default=False, doc="""
        Toggles whether the periodic callback is currently running.""")

    def __init__(self, **params):
        super().__init__(**params)
        self._counter = 0
        self._start_time = None
        # Handle of the scheduled callback; None while stopped.
        self._cb = None
        # True while start()/stop() set `running` themselves, so that the
        # `running` watchers below do not re-enter.
        self._updating = False
        # Document the callback was registered on, if any.
        self._doc = None

    @param.depends('running', watch=True)
    def _start(self):
        # Reacts to `running` being set to True from outside.
        if not self.running or self._updating:
            return
        self.start()

    @param.depends('running', watch=True)
    def _stop(self):
        # Reacts to `running` being set to False from outside.
        if self.running or self._updating:
            return
        self.stop()

    @param.depends('period', watch=True)
    def _update_period(self):
        # Restart so that a changed period takes effect.
        if self._cb:
            self.stop()
            self.start()

    def _periodic_callback(self):
        """Run the user callback once and enforce timeout and count."""
        with edit_readonly(state):
            state.busy = True
        try:
            self.callback()
        finally:
            with edit_readonly(state):
                state.busy = False
        self._counter += 1
        if self.timeout is not None:
            # Elapsed time since start() in milliseconds.
            dt = (time.time() - self._start_time) * 1000
            if dt > self.timeout:
                self.stop()
        if self._counter == self.count:
            self.stop()

    @property
    def counter(self):
        """
        Returns the execution count of the periodic callback.
        """
        return self._counter

    def _cleanup(self, session_context):
        # Session-destroyed hook: stop when the session goes away.
        self.stop()

    def start(self):
        """
        Starts running the periodic callback.

        Raises a RuntimeError if the callback has already been started.
        """
        if self._cb is not None:
            raise RuntimeError('Periodic callback has already started.')
        if not self.running:
            try:
                self._updating = True
                self.running = True
            finally:
                self._updating = False
        self._start_time = time.time()
        if state.curdoc:
            self._doc = state.curdoc
            self._cb = self._doc.add_periodic_callback(
                self._periodic_callback, self.period)
        else:
            from tornado.ioloop import PeriodicCallback
            self._cb = PeriodicCallback(self._periodic_callback, self.period)
            self._cb.start()
        try:
            state.on_session_destroyed(self._cleanup)
        except Exception:
            # Best-effort registration; failure must not prevent starting.
            pass

    def stop(self):
        """
        Stops running the periodic callback.
        """
        if self.running:
            try:
                self._updating = True
                self.running = False
            finally:
                self._updating = False
        self._counter = 0
        # NOTE(review): `_timeout` is not read anywhere in this class;
        # possibly `_start_time` was intended -- confirm before changing.
        self._timeout = None
        if self._doc:
            self._doc.remove_periodic_callback(self._cb)
        elif self._cb:
            self._cb.stop()
        self._cb = None
        doc = self._doc or _curdoc()
        if doc:
            # Every access to `self._cleanup` creates a new bound-method
            # object, so an identity test can never match the registered
            # hook.  Bound methods compare equal when they wrap the same
            # function on the same instance, so compare by equality.
            cleanup = self._cleanup
            doc.session_destroyed_callbacks = {
                cb for cb in doc.session_destroyed_callbacks
                if cb != cleanup
            }
            self._doc = None
class UpdateManager:
    """Coordinates the configured updaters and exposes the
    ``/machine/update/*`` endpoints.

    ``cmd_request_lock`` serializes update and refresh work.
    ``is_refreshing`` tells concurrent requests that a refresh is already
    under way so they do not start another one.
    """

    def __init__(self, config):
        """Create the updaters, register endpoints and schedule the initial
        refresh.

        Raises:
            config.error: for an unsupported distro, a duplicate client name
                or an unknown client type.
        """
        self.server = config.get_server()
        self.config = config
        self.config.read_supplemental_config(SUPPLEMENTAL_CFG_PATH)
        auto_refresh_enabled = config.getboolean('enable_auto_refresh', False)
        self.distro = config.get('distro', "debian").lower()
        if self.distro not in SUPPORTED_DISTROS:
            raise config.error(f"Unsupported distro: {self.distro}")
        self.cmd_helper = CommandHelper(config)
        env = sys.executable
        mooncfg = self.config[f"update_manager static {self.distro} moonraker"]
        self.updaters = {
            "system": PackageUpdater(self.cmd_helper),
            "moonraker": GitUpdater(mooncfg, self.cmd_helper,
                                    MOONRAKER_PATH, env)
        }
        # TODO: Check for client config in [update_manager].  This is
        # deprecated and will be removed.
        client_repo = config.get("client_repo", None)
        if client_repo is not None:
            client_path = config.get("client_path")
            name = client_repo.split("/")[-1]
            self.updaters[name] = WebUpdater(
                {
                    'repo': client_repo,
                    'path': client_path
                }, self.cmd_helper)
        client_sections = self.config.get_prefix_sections(
            "update_manager client")
        for section in client_sections:
            cfg = self.config[section]
            # The last word of the section header is the client name.
            name = section.split()[-1]
            if name in self.updaters:
                raise config.error("Client repo named %s already added" %
                                   (name, ))
            client_type = cfg.get("type")
            if client_type == "git_repo":
                self.updaters[name] = GitUpdater(cfg, self.cmd_helper)
            elif client_type == "web":
                self.updaters[name] = WebUpdater(cfg, self.cmd_helper)
            else:
                raise config.error("Invalid type '%s' for section [%s]" %
                                   (client_type, section))

        self.cmd_request_lock = Lock()
        self.is_refreshing = False

        # Auto Status Refresh
        self.last_auto_update_time = 0
        self.refresh_cb = None
        if auto_refresh_enabled:
            self.refresh_cb = PeriodicCallback(self._handle_auto_refresh,
                                               UPDATE_REFRESH_INTERVAL_MS)
            self.refresh_cb.start()

        self.server.register_endpoint("/machine/update/moonraker", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/klipper", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/system", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/client", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/status", ["GET"],
                                      self._handle_status_request)
        self.server.register_notification("update_manager:update_response")
        self.server.register_notification("update_manager:update_refreshed")

        # Register Ready Event
        self.server.register_event_handler("server:klippy_identified",
                                           self._set_klipper_repo)
        # Initialize GitHub API Rate Limits and configured updaters
        IOLoop.current().spawn_callback(self._initalize_updaters,
                                        list(self.updaters.values()))

    async def _initalize_updaters(self, initial_updaters):
        """Initialize rate limits and refresh every initial updater.

        ``is_refreshing`` is cleared in a ``finally`` clause.  If it stayed
        True after a failure here, every later status or auto refresh would
        assume a refresh is in progress and never refresh again.
        """
        self.is_refreshing = True
        try:
            await self.cmd_helper.init_api_rate_limit()
            for updater in initial_updaters:
                if isinstance(updater, PackageUpdater):
                    ret = updater.refresh(False)
                else:
                    ret = updater.refresh()
                # refresh() may be either a plain function or a coroutine.
                if asyncio.iscoroutine(ret):
                    await ret
        finally:
            self.is_refreshing = False

    async def _set_klipper_repo(self):
        """Create or replace the klipper updater from the klippy info."""
        kinfo = self.server.get_klippy_info()
        if not kinfo:
            logging.info("No valid klippy info received")
            return
        kpath = kinfo['klipper_path']
        env = kinfo['python_path']
        kupdater = self.updaters.get('klipper', None)
        if kupdater is not None and kupdater.repo_path == kpath and \
                kupdater.env == env:
            # Current Klipper Updater is valid
            return
        kcfg = self.config[f"update_manager static {self.distro} klipper"]
        self.updaters['klipper'] = GitUpdater(kcfg, self.cmd_helper, kpath,
                                              env)
        await self.updaters['klipper'].refresh()

    async def _check_klippy_printing(self):
        """Return True when print_stats reports the state "printing"."""
        klippy_apis = self.server.lookup_plugin('klippy_apis')
        result = await klippy_apis.query_objects({'print_stats': None},
                                                 default={})
        pstate = result.get('print_stats', {}).get('state', "")
        return pstate.lower() == "printing"

    async def _handle_auto_refresh(self):
        """Periodic refresh of all updaters, followed by a status event.

        Skipped while printing, when the last auto refresh was less than
        MIN_REFRESH_TIME ago, or when the local hour is at or past
        MAX_PKG_UPDATE_HOUR.
        """
        if await self._check_klippy_printing():
            # Don't Refresh during a print
            logging.info("Klippy is printing, auto refresh aborted")
            return
        cur_time = time.time()
        cur_hour = time.localtime(cur_time).tm_hour
        time_diff = cur_time - self.last_auto_update_time
        # Update packages if it has been more than 12 hours
        # and the local time is between 12AM and 5AM
        if time_diff < MIN_REFRESH_TIME or cur_hour >= MAX_PKG_UPDATE_HOUR:
            # Not within the update time window
            return
        self.last_auto_update_time = cur_time
        vinfo = {}
        # Sampled before taking the lock: if another task is refreshing,
        # only its results are collected once the lock is acquired.
        need_refresh_all = not self.is_refreshing
        async with self.cmd_request_lock:
            self.is_refreshing = True
            try:
                for name, updater in list(self.updaters.items()):
                    if need_refresh_all:
                        ret = updater.refresh()
                        if asyncio.iscoroutine(ret):
                            await ret
                    if hasattr(updater, "get_update_status"):
                        vinfo[name] = updater.get_update_status()
            except Exception:
                logging.exception("Unable to Refresh Status")
                return
            finally:
                self.is_refreshing = False
        uinfo = self.cmd_helper.get_rate_limit_stats()
        uinfo['version_info'] = vinfo
        uinfo['busy'] = self.cmd_helper.is_update_busy()
        self.server.send_event("update_manager:update_refreshed", uinfo)

    async def _handle_update_request(self, web_request):
        """Run the updater named by the endpoint (or by ``name`` for clients).

        Returns:
            "ok" on success, or a message when the app is already updating.

        Raises:
            The server's error type while printing or for an unknown updater;
            any error raised by the updater after notifying listeners.
        """
        if await self._check_klippy_printing():
            raise self.server.error("Update Refused: Klippy is printing")
        app = web_request.get_endpoint().split("/")[-1]
        if app == "client":
            app = web_request.get('name')
        inc_deps = web_request.get_boolean('include_deps', False)
        if self.cmd_helper.is_app_updating(app):
            return f"Object {app} is currently being updated"
        updater = self.updaters.get(app, None)
        if updater is None:
            raise self.server.error(f"Updater {app} not available")
        async with self.cmd_request_lock:
            self.cmd_helper.set_update_info(app, id(web_request))
            try:
                await updater.update(inc_deps)
            except Exception as e:
                self.cmd_helper.notify_update_response(
                    f"Error updating {app}")
                self.cmd_helper.notify_update_response(str(e),
                                                       is_complete=True)
                raise
            finally:
                self.cmd_helper.clear_update_info()
        return "ok"

    async def _handle_status_request(self, web_request):
        """Return rate limit stats plus the status of every updater.

        With ``refresh=True`` the updaters are refreshed first, unless an
        update or a print is in progress.
        """
        check_refresh = web_request.get_boolean('refresh', False)
        # Don't refresh if a print is currently in progress or
        # if an update is in progress.  Just return the current
        # state
        if self.cmd_helper.is_update_busy() or \
                await self._check_klippy_printing():
            check_refresh = False
        need_refresh = False
        if check_refresh:
            # If there is an outstanding request processing a
            # refresh, we don't need to do it again.
            need_refresh = not self.is_refreshing
            await self.cmd_request_lock.acquire()
            self.is_refreshing = True
        vinfo = {}
        try:
            for name, updater in list(self.updaters.items()):
                await updater.check_initialized(120.)
                if need_refresh:
                    ret = updater.refresh()
                    if asyncio.iscoroutine(ret):
                        await ret
                if hasattr(updater, "get_update_status"):
                    vinfo[name] = updater.get_update_status()
        finally:
            # The lock is only held when a refresh was requested.
            if check_refresh:
                self.is_refreshing = False
                self.cmd_request_lock.release()
        ret = self.cmd_helper.get_rate_limit_stats()
        ret['version_info'] = vinfo
        ret['busy'] = self.cmd_helper.is_update_busy()
        return ret

    def close(self):
        """Close the command helper and stop the auto refresh, if enabled."""
        self.cmd_helper.close()
        if self.refresh_cb is not None:
            self.refresh_cb.stop()
class AzurePreemptibleWorkerPlugin(WorkerPlugin):
    """A worker plugin for azure spot instances

    This worker plugin will poll azure's metadata service for preemption
    notifications. When a node is preempted, the plugin will attempt to
    gracefully shut down all workers on the node.

    This plugin can be used on any worker running on azure spot instances,
    not just the ones created by ``dask-cloudprovider``.

    For more details on azure spot instances see:
    https://docs.microsoft.com/en-us/azure/virtual-machines/linux/scheduled-events

    Parameters
    ----------
    poll_interval_s: int (optional)
        The rate at which the plugin will poll the metadata service in
        seconds.

        Defaults to ``1``

    metadata_url: str (optional)
        The url of the metadata service to poll.

        Defaults to
        "http://169.254.169.254/metadata/scheduledevents?api-version=2019-08-01"

    termination_events: List[str] (optional)
        The type of events that will trigger the graceful shutdown

        Defaults to ``['Preempt', 'Terminate']``

    termination_offset_minutes: int (optional)
        Extra offset to apply to the preemption date. This may be negative,
        to start the graceful shutdown before the ``NotBefore`` date. It can
        also be positive, to start the shutdown after the ``NotBefore`` date,
        but this is at your own risk.

        Defaults to ``0``

    Examples
    --------

    Let's say you have cluster and a client instance.
    For example using :class:`dask_kubernetes.KubeCluster`

    >>> from dask_kubernetes import KubeCluster
    >>> from distributed import Client
    >>> cluster = KubeCluster()
    >>> client = Client(cluster)

    You can add the worker plugin using the following:

    >>> from dask_cloudprovider.azure import AzurePreemptibleWorkerPlugin
    >>> client.register_worker_plugin(AzurePreemptibleWorkerPlugin())
    """

    def __init__(
        self,
        poll_interval_s=1,
        metadata_url=None,
        termination_events=None,
        termination_offset_minutes=0,
    ):
        self.callback = None
        self.loop = None
        self.worker = None
        self.poll_interval_s = poll_interval_s
        self.metadata_url = metadata_url or AZURE_EVENTS_METADATA_URL
        self.termination_events = termination_events or [
            "Preempt", "Terminate"
        ]
        self.termination_offset = datetime.timedelta(
            minutes=termination_offset_minutes)

        # Set once a shutdown has been triggered; later polls become no-ops.
        self.terminating = False
        # Latest ``NotBefore`` timestamp seen for a termination event.
        self.not_before = None
        # Created lazily in ``poll_status``.
        self._session = None
        self._lock = None

    async def _is_terminating(self):
        """Poll the metadata service once and decide whether to shut down.

        Returns a truthy value when a termination event has status
        ``"Started"`` or when the recorded ``NotBefore`` date plus the
        configured offset is in the past. Returns a falsy value (possibly
        ``None``) otherwise, including when the response is not JSON.
        """
        preempt_started = False
        async with self._session.get(self.metadata_url) as response:
            try:
                data = await response.json()
            # Sometime azure responds with text/plain mime type
            except aiohttp.ContentTypeError:
                return

            # Sometimes the response doesn't contain the Events key
            events = data.get("Events", [])
            if events:
                logger.debug("Worker {}, got metadata events {}".format(
                    self.worker.name, events))

            for evt in events:
                event_type = evt["EventType"]
                if event_type not in self.termination_events:
                    continue

                event_status = evt.get("EventStatus")
                if event_status == "Started":
                    logger.info("Worker {}, node preemption started".format(
                        self.worker.name))
                    preempt_started = True
                    break

                not_before = evt.get("NotBefore")
                if not not_before:
                    continue

                # Parsed as a naive datetime; compared against utcnow() below.
                not_before = datetime.datetime.strptime(
                    not_before, "%a, %d %b %Y %H:%M:%S GMT")

                if self.not_before is None:
                    # Log the freshly parsed date: ``self.not_before`` is
                    # still None at this point.
                    logger.info(
                        "Worker {}, node deletion scheduled not before {}".
                        format(self.worker.name, not_before))
                    self.not_before = not_before
                    break

                if self.not_before < not_before:
                    logger.info(
                        "Worker {}, node deletion re-scheduled not before {}".
                        format(self.worker.name, not_before))
                    self.not_before = not_before
                    break

        return preempt_started or (self.not_before and
                                   (self.not_before + self.termination_offset
                                    < datetime.datetime.utcnow()))

    async def poll_status(self):
        """Periodic callback: check for termination and close the worker."""
        if self.terminating:
            return
        if self._session is None:
            self._session = aiohttp.ClientSession(headers={"Metadata": "true"})
        if self._lock is None:
            self._lock = asyncio.Lock()

        async with self._lock:
            # A previous poll may have started the shutdown (and closed the
            # session) while this one was waiting for the lock.
            if self.terminating:
                return

            is_terminating = await self._is_terminating()
            if not is_terminating:
                return

            logger.info(
                "Worker {}, node is being deleted, attempting graceful shutdown"
                .format(self.worker.name))
            self.terminating = True
            await self._session.close()
            await self.worker.close_gracefully()

    def setup(self, worker):
        """Register the periodic poll on the current IOLoop for ``worker``."""
        self.worker = worker
        self.loop = IOLoop.current()
        self.callback = PeriodicCallback(self.poll_status,
                                         callback_time=self.poll_interval_s *
                                         1_000)
        self.loop.add_callback(self.callback.start)
        logger.debug("Worker {}, registering preemptible plugin".format(
            self.worker.name))

    def teardown(self, worker):
        """Stop the periodic poll, if one was started."""
        logger.debug("Worker {}, tearing down plugin".format(self.worker.name))
        if self.callback:
            self.callback.stop()
            self.callback = None
class Client(object):
    """Kodi Connect Websocket Connection"""

    def __init__(self, url, kodi, handler):
        self.url = url
        self.websocket = None
        self.connected = False
        self.should_stop = False
        self.kodi = kodi
        self.handler = handler
        # Reconnect / ping / cache-refresh tick (20000 is passed as the
        # PeriodicCallback callback_time).
        self.periodic = PeriodicCallback(self.periodic_callback, 20000)

    def start(self):
        """Start IO loop and try to connect to the server"""
        self.connect()
        self.periodic.start()
        IOLoop.current().start()

    def stop(self):
        """Stop IO loop"""
        self.should_stop = True
        if self.websocket is not None:
            self.websocket.close()
        self.periodic.stop()
        IOLoop.current().stop()

    @gen.coroutine
    def connect(self):
        """Connect to the server and update connection to websocket"""
        email = __addon__.getSetting('email')
        secret = __addon__.getSetting('secret')
        if not email or not secret:
            logger.debug('Email and/or secret not defined, not connecting')
            return

        logger.debug('trying to connect')
        try:
            request = HTTPRequest(self.url,
                                  auth_username=email,
                                  auth_password=secret)
            self.websocket = yield websocket_connect(request)
        except Exception as ex:
            logger.debug('connection error: {}'.format(str(ex)))
            self.websocket = None
            self.connected = False
            notification(strings.FAILED_TO_CONNECT,
                         level='error',
                         tag='connection')
        else:
            logger.debug('Connected')
            self.connected = True
            notification(strings.CONNECTED, tag='connection')
            self.run()

    @gen.coroutine
    def run(self):
        """Main loop handling incoming messages

        Reads messages until the connection closes. Every message gets a
        reply carrying the request's ``correlationId`` (``None`` when the
        message could not be parsed or carries no id).
        """
        while True:
            message_str = yield self.websocket.read_message()
            if message_str is None:
                logger.debug('Connection closed')
                self.websocket = None
                self.connected = False
                notification(strings.DISCONNECTED,
                             level='warn',
                             tag='connection')
                break

            # Reset per message so a parse failure can neither raise on an
            # unbound name nor reuse the previous message's id.
            correlation_id = None
            try:
                message = json.loads(message_str)
                logger.debug(message)
                if isinstance(message, dict):
                    correlation_id = message.get('correlationId')
                data = message['data']
                response_data = self.handler.handler(data)
            except Exception as ex:
                logger.error('Handler failed: {}'.format(str(ex)))
                response_data = {"status": "error", "error": "Unknown error"}

            self.websocket.write_message(
                json.dumps({
                    "correlationId": correlation_id,
                    "data": response_data
                }))

    def periodic_callback(self):
        """Periodic callback

        Reconnects when there is no websocket, otherwise sends a ping, then
        refreshes the Kodi cache (failures are logged, not raised).
        """
        logger.debug('periodic_callback')
        if self.websocket is None:
            self.connect()
        else:
            self.websocket.write_message(json.dumps({"ping": "pong"}))

        try:
            self.kodi.update_cache()
        except Exception as ex:
            logger.error('Failed to update Kodi library: {}'.format(str(ex)))
class ProcStats:
    """Collect process CPU/memory statistics on a periodic callback.

    Each tick appends a sample to a bounded queue and emits a
    ``proc_stats:proc_stat_update`` event. When the ``vcgencmd`` file
    exists, the throttled state is also queried every
    ``THROTTLE_CHECK_INTERVAL`` ticks.
    """

    def __init__(self, config: ConfigHelper) -> None:
        self.server = config.get_server()
        self.ioloop = IOLoop.current()
        self.stat_update_cb = PeriodicCallback(
            self._handle_stat_update, STAT_UPDATE_TIME_MS)  # type: ignore
        self.vcgencmd: Optional[shell_command.ShellCommand] = None
        if not os.path.exists(VC_GEN_CMD_FILE):
            logging.info("Unable to find 'vcgencmd', throttle checking "
                         "disabled")
        else:
            logging.info("Detected 'vcgencmd', throttle checking enabled")
            shell_cmd: shell_command.ShellCommandFactory
            shell_cmd = self.server.load_component(config, "shell_command")
            self.vcgencmd = shell_cmd.build_shell_command(
                "vcgencmd get_throttled")
            self.server.register_notification("proc_stats:cpu_throttled")
        self.temp_file = pathlib.Path(TEMPERATURE_PATH)
        self.smaps = pathlib.Path(STATM_FILE_PATH)
        self.server.register_endpoint(
            "/machine/proc_stats", ["GET"], self._handle_stat_request)
        self.server.register_event_handler(
            "server:klippy_shutdown", self._handle_shutdown)
        self.server.register_notification("proc_stats:proc_stat_update")
        # Only the 30 most recent samples are retained.
        self.proc_stat_queue: Deque[Dict[str, Any]] = deque(maxlen=30)
        self.last_update_time = time.time()
        self.last_proc_time = time.process_time()
        self.throttle_check_lock = Lock()
        self.total_throttled: int = 0
        self.last_throttled: int = 0
        self.update_sequence: int = 0
        self.stat_update_cb.start()

    async def _handle_stat_request(self,
                                   web_request: WebRequest
                                   ) -> Dict[str, Any]:
        """Return queued samples, throttled state and CPU temperature."""
        throttled: Optional[Dict[str, Any]] = None
        if self.vcgencmd is not None:
            throttled = await self._check_throttled_state()
        return {
            'moonraker_stats': list(self.proc_stat_queue),
            'throttled_state': throttled,
            'cpu_temp': self._get_cpu_temperature()
        }

    async def _handle_shutdown(self) -> None:
        """Log every queued sample plus temperature and throttle flags."""
        parts = ["\nMoonraker System Usage Statistics:"]
        parts.extend(f"\n{self._format_stats(sample)}"
                     for sample in self.proc_stat_queue)
        parts.append(f"\nCPU Temperature: {self._get_cpu_temperature()}")
        logging.info("".join(parts))
        if self.vcgencmd is not None:
            ts = await self._check_throttled_state()
            logging.info(f"Throttled Flags: {' '.join(ts['flags'])}")

    async def _handle_stat_update(self) -> None:
        """Take one sample, publish it, and occasionally check throttling."""
        now = time.time()
        proc_now = time.process_time()
        elapsed = now - self.last_update_time
        usage = round((proc_now - self.last_proc_time) / elapsed * 100, 2)
        mem, mem_units = self._get_memory_usage()
        cpu_temp = self._get_cpu_temperature()
        sample = {
            "time": now,
            "cpu_usage": usage,
            "memory": mem,
            "mem_units": mem_units,
        }
        self.proc_stat_queue.append(sample)
        self.server.send_event("proc_stats:proc_stat_update", {
            'moonraker_stats': sample,
            'cpu_temp': cpu_temp
        })
        self.last_update_time = now
        self.last_proc_time = proc_now

        self.update_sequence += 1
        if self.update_sequence != THROTTLE_CHECK_INTERVAL:
            return
        self.update_sequence = 0
        if self.vcgencmd is None:
            return

        ts = await self._check_throttled_state()
        bits = ts['bits']
        # Bits not seen before go into the log rollover info.
        if bits & ~self.total_throttled:
            self.server.add_log_rollover_item(
                'throttled', f"CPU Throttled Flags: {ts['flags']}")
        if bits != self.last_throttled:
            self.server.send_event("proc_stats:cpu_throttled", ts)
        self.last_throttled = bits
        self.total_throttled |= bits

    async def _check_throttled_state(self) -> Dict[str, Any]:
        """Run ``vcgencmd`` and decode its hexadecimal result into flags.

        Any failure yields ``{'bits': 0, 'flags': ["?"]}``.
        """
        async with self.throttle_check_lock:
            assert self.vcgencmd is not None
            try:
                resp = await self.vcgencmd.run_with_response(
                    timeout=.5, log_complete=False)
                ts = int(resp.strip().split("=")[-1], 16)
            except Exception:
                return {'bits': 0, 'flags': ["?"]}
            descriptions = [desc for flag, desc in THROTTLED_FLAGS.items()
                            if flag & ts]
            return {'bits': ts, 'flags': descriptions}

    def _get_memory_usage(self) -> Tuple[Optional[int], Optional[str]]:
        """Return ``(rss, units)`` parsed from the smaps file, or Nones."""
        try:
            found = re.search(r"Rss:\s+(\d+)\s+(\w+)", self.smaps.read_text())
            if found is None:
                return None, None
            return int(found.group(1)), found.group(2)
        except Exception:
            return None, None

    def _get_cpu_temperature(self) -> Optional[float]:
        """Return the temperature file's value divided by 1000, or None."""
        if not self.temp_file.exists():
            return None
        try:
            return int(self.temp_file.read_text().strip()) / 1000.
        except Exception:
            return None

    def _format_stats(self, stats: Dict[str, Any]) -> str:
        """Render one sample as a single log line."""
        return (f"System Time: {stats['time']:2f}, "
                f"Usage: {stats['cpu_usage']}%, "
                f"Memory: {stats['memory']} {stats['mem_units']}")

    def close(self) -> None:
        """Stop the periodic sampling callback."""
        self.stat_update_cb.stop()
class JobScheduler:
    """a FIFO scheduler

    A PeriodicCallback calls :meth:`attempt_scheduling` every ``interval``.
    Each tick releases executors whose job reached a final status and then
    tries to allocate executors for jobs that are still in the init status,
    in order, stopping at the first job for which no resource is available.
    """

    def __init__(self,
                 batch,
                 exit_on_finish,
                 interval,
                 executor_manager: ExecutorManager,
                 callbacks=None):
        self.batch = batch
        self.exit_on_finish = exit_on_finish
        self.executor_manager = executor_manager
        self.callbacks = callbacks if callbacks is not None else []
        self._io_loop_instance = None
        self._timer = PeriodicCallback(self.attempt_scheduling, interval)
        self._n_skipped = 0
        self._n_allocated = 0
        self._selected_jobs = []

    @property
    def n_skipped(self):
        """Number of jobs skipped at start because they were not in init."""
        return self._n_skipped

    @property
    def n_allocated(self):
        """Number of jobs for which an executor was allocated."""
        return self._n_allocated

    @property
    def interval(self):
        """The scheduling timer's ``callback_time``."""
        return self._timer.callback_time

    def start(self):
        """Prepare executors, select runnable jobs and run the IO loop.

        Blocks until the IO loop is stopped.
        """
        self.executor_manager.prepare()
        self._timer.start()
        self.batch.start_time = time.time()

        # stats finished jobs
        for job in self.batch.jobs:
            job_status = self.batch.get_job_status(job.name)
            if job_status != ShellJob.STATUS_INIT:
                logger.info(
                    f"job '{job.name}' status is '{job_status}', skip run.")
                self._n_skipped = self.n_skipped + 1
            else:
                self._selected_jobs.append(job)

        for callback in self.callbacks:
            callback.on_start(self.batch)

        # run in io loop
        self._io_loop_instance = tornado.ioloop.IOLoop.instance()
        logger.info('starting io loop')
        self._io_loop_instance.start()
        logger.info('exited io loop')

    def stop(self):
        """Ask the IO loop to stop itself.

        Raises
        ------
        RuntimeError
            If :meth:`start` has not been called yet.
        """
        if self._io_loop_instance is None:
            raise RuntimeError("Not started yet")
        # Let the ioloop stop itself: calling stop() directly does not work
        # from another thread.
        self._io_loop_instance.add_callback(self._io_loop_instance.stop)
        logger.info("add a stop callback to ioloop")

    def kill_job(self, job_name):
        """Kill the executor of a running job and mark the job as failed.

        Raises
        ------
        ValueError
            If the job does not exist or has no executor.
        RuntimeError
            If the job is not in the running status.
        """
        # checkout job
        job = self.batch.get_job_by_name(job_name)
        if job is None:
            raise ValueError(f'job {job_name} does not exists ')

        job_status = self.batch.get_job_status(job.name)
        logger.info(
            f"trying kill job {job_name}, it's status is {job_status} ")

        # check job status
        if job_status != job.STATUS_RUNNING:
            raise RuntimeError(
                f"job {job_name} in not in {job.STATUS_RUNNING} status but is {job_status} "
            )

        # find executor and kill
        em = self.executor_manager
        executor = em.get_executor(job)
        logger.info(f"find executor {executor} of job {job_name}")
        if executor is None:
            raise ValueError(f"no executor found for job {job.name}")

        em.kill_executor(executor)
        logger.info(f"write failed status file for {job_name}")
        self._change_job_status(job, job.STATUS_FAILED)

    def _change_job_status(self, job: ShellJob, next_status):
        self.change_job_status(self.batch, job, next_status)

    @staticmethod
    def change_job_status(batch: Batch, job: ShellJob, next_status):
        """Move ``job`` to ``next_status`` by touching its status file.

        Allowed transitions are init -> running and running -> any status in
        ``job.FINAL_STATUS``; the running status file is removed in the
        latter case.

        Raises
        ------
        ValueError
            For any other transition or an unknown status.
        """
        current_status = batch.get_job_status(job_name=job.name)
        target_status_file = batch.job_status_file_path(job_name=job.name,
                                                        status=next_status)

        def touch(f_path):
            with open(f_path, 'w'):
                pass

        if next_status == job.STATUS_INIT:
            raise ValueError(f"can not change to {next_status} ")

        elif next_status == job.STATUS_RUNNING:
            if current_status != job.STATUS_INIT:
                raise ValueError(
                    f"only job in {job.STATUS_INIT} can change to {next_status}"
                )
            touch(target_status_file)

        elif next_status in job.FINAL_STATUS:
            if current_status != job.STATUS_RUNNING:
                raise ValueError(
                    f"only job in {job.STATUS_RUNNING} can change to "
                    f"{next_status} but now is {current_status}")
            # remove running status
            os.remove(
                batch.job_status_file_path(job_name=job.name,
                                           status=job.STATUS_RUNNING))
            touch(target_status_file)
            reload_status = batch.get_job_status(job_name=job.name)
            assert reload_status == next_status, f"change job status failed, current status is {reload_status}," \
                                                 f" expected status is {next_status}"
        else:
            raise ValueError(f"unknown status {next_status}")

    def _release_executors(self, executor_manager):
        """Finalize every waiting executor whose status is final."""
        finished = [
            executor for executor in executor_manager.waiting_executors()
            if executor.status() in ShellJob.FINAL_STATUS
        ]

        for finished_executor in finished:
            executor_status = finished_executor.status()
            job = finished_executor.job
            logger.info(
                f"job {job.name} finished with status {executor_status}")
            # Persist the same status that was just logged.
            self._change_job_status(job, executor_status)
            job.end_time = time.time()  # update end time
            executor_manager.release_executor(finished_executor)
            self._handle_job_finished(job, finished_executor, job.elapsed)

    def _handle_callbacks(self, func):
        """Apply ``func`` to every callback; failures are logged, not raised."""
        for callback in self.callbacks:
            try:
                func(callback)
            except Exception as e:
                # The exception is passed as a %-format argument so that the
                # log record can actually be rendered.
                logger.warning("handle callback failed: %s", e)

    def _handle_job_start(self, job, executor):
        def f(callback):
            callback.on_job_start(self.batch, job, executor)

        self._handle_callbacks(f)

    def _handle_job_finished(self, job, executor, elapsed):
        def f(callback):
            callback.on_job_finish(self.batch, job, executor, elapsed)

        self._handle_callbacks(f)

    def _run_jobs(self, executor_manager):
        """Allocate executors for init jobs, in order, while resources last."""
        jobs = self._selected_jobs
        for job in jobs:
            if self.batch.get_job_status(job.name) != job.STATUS_INIT:
                continue

            try:
                executor = executor_manager.alloc_executor(job)
            except NoResourceException:
                # Not enough resource: retry on the next tick (FIFO order).
                break
            except Exception:
                # skip the job, and do not clean the executor
                # NOTE(review): the job is still in the init status here and
                # change_job_status only accepts a final status from the
                # running status, so this call raises ValueError - confirm
                # the intended transition.
                self._change_job_status(job, job.STATUS_FAILED)  # TODO on job break
                # logger.exception() records the active traceback itself.
                logger.exception(
                    f"failed to alloc resource for job '{job.name}' ")
                continue

            self._n_allocated = self.n_allocated + 1
            job.start_time = time.time()  # update start time
            self._handle_job_start(job, executor)
            process_msg = f"{len(executor_manager.allocated_executors())}/{len(jobs)}"
            logger.info(
                f'allocated resource for job {job.name}({process_msg}), data dir at {job.data_dir_path}'
            )
            self._change_job_status(job, job.STATUS_RUNNING)

            try:
                executor.run()
            except Exception:
                logger.exception(f"failed to run job '{job.name}' ")
                self._change_job_status(job, job.STATUS_FAILED)  # TODO on job break
                executor_manager.release_executor(executor)
                continue

    def _handle_on_finished(self):
        for callback in self.callbacks:
            callback.on_finish(self.batch, self.batch.elapsed)

    def attempt_scheduling(self):
        """Timer tick: finish the batch if done, else release and schedule."""
        # check all jobs finished
        job_finished = self.batch.is_finished()
        if job_finished:
            self.batch.end_time = time.time()
            batch_summary = json.dumps(self.batch.summary())
            logger.info("all jobs finished, stop scheduler:\n" + batch_summary)
            self._timer.stop()  # stop the timer
            if self.exit_on_finish:
                self.stop()
            self._handle_on_finished()
            return

        self._release_executors(self.executor_manager)
        self._run_jobs(self.executor_manager)
def test__lifecycle_hooks(ManagedServerLoop) -> None:
    """Check the order in which server and session lifecycle hooks fire."""
    application = Application()
    handler = HookTestHandler()
    application.add(handler)

    with ManagedServerLoop(application,
                           check_unused_sessions_milliseconds=30) as server:
        loop = server.io_loop

        # Let the server-level callbacks run before a session is mixed in;
        # this keeps the test deterministic.
        def stop_once_server_hooks_ran():
            if len(handler.hooks) == 4:
                loop.stop()

        load_poller = PeriodicCallback(stop_once_server_hooks_ran, 1)
        load_poller.start()
        loop.start()
        load_poller.stop()

        # Now create a session.
        client_session = pull_session(session_id='test__lifecycle_hooks',
                                      url=url(server),
                                      io_loop=loop)
        client_doc = client_session.document
        assert len(client_doc.roots) == 1

        server_session = server.get_session('/', client_session.id)
        server_doc = server_session.document
        assert len(server_doc.roots) == 1

        # Grab the roots now: once the session is closed, doc.roots is
        # emptied and they could no longer be examined.
        client_hook_list = client_doc.roots[0]
        server_hook_list = server_doc.roots[0]

        client_session.close()
        # Expire the session right away instead of waiting for the usual
        # timeout.
        server_session.request_expiration()

        loop.call_later(0.1, loop.stop)
        loop.start()

    expected_hooks = [
        "server_loaded",
        "next_tick_server",
        "timeout_server",
        "periodic_server",
        "session_created",
        "modify",
        "next_tick_session",
        "timeout_session",
        "periodic_session",
        "session_destroyed",
        "server_unloaded",
    ]
    assert handler.hooks == expected_hooks
    assert handler.load_count == 1
    assert handler.unload_count == 1
    # this is 3 instead of 6 because locked callbacks on destroyed sessions
    # are turned into no-ops
    assert handler.session_creation_async_value == 3
    assert client_doc.title == "Modified"
    assert server_doc.title == "Modified"

    # Only the handler sees the event that adds "session_destroyed", since
    # the session is shut down at that point.
    session_hooks = ["session_created", "modify"]
    assert client_hook_list.hooks == session_hooks
    assert server_hook_list.hooks == session_hooks
class JsonStatusServer(AsyncDeviceServer):
    """Device server exposing status values as sensors.

    Unless ``dummy`` is set, a ``StatusCatcherThread`` supplies data and a
    periodic callback copies it into the sensors via each parser entry's
    ``"updater"`` callable. Sensors placed under user control are skipped by
    that update and can be set through ``request_sensor_set``.
    """
    VERSION_INFO = ("reynard-eff-jsonstatusserver-api", 0, 1)
    BUILD_INFO = ("reynard-eff-jsonstatusserver-implementation", 0, 1, "rc1")

    def __init__(self,
                 server_host,
                 server_port,
                 mcast_group=JSON_STATUS_MCAST_GROUP,
                 mcast_port=JSON_STATUS_PORT,
                 parser=EFF_JSON_CONFIG,
                 dummy=False):
        self._mcast_group = mcast_group
        self._mcast_port = mcast_port
        self._parser = parser
        self._dummy = dummy
        if not dummy:
            self._catcher_thread = StatusCatcherThread()
        else:
            self._catcher_thread = None
        self._monitor = None
        self._updaters = {}
        # Names of sensors excluded from automatic updates.
        self._controlled = set()
        super(JsonStatusServer, self).__init__(server_host, server_port)

    @coroutine
    def _update_sensors(self):
        """Push the catcher thread's latest data into uncontrolled sensors."""
        log.debug("Updating sensor values")
        data = self._catcher_thread.data
        if data is None:
            log.warning("Catcher thread has not received any data yet")
            return
        for name, params in self._parser.items():
            if name in self._controlled:
                continue
            if "updater" in params:
                self._sensors[name].set_value(params["updater"](data))

    def start(self):
        """start the server"""
        super(JsonStatusServer, self).start()
        if not self._dummy:
            self._catcher_thread.start()
            self._monitor = PeriodicCallback(self._update_sensors,
                                             1000,
                                             io_loop=self.ioloop)
            self._monitor.start()

    def stop(self):
        """stop the server"""
        if not self._dummy:
            if self._monitor:
                self._monitor.stop()
            self._catcher_thread.stop()
        return super(JsonStatusServer, self).stop()

    @request()
    @return_reply(Str())
    def request_xml(self, req):
        """request an XML version of the status message"""
        def make_elem(parent, name, text):
            child = etree.Element(name)
            child.text = text
            parent.append(child)

        @coroutine
        def convert():
            try:
                root = etree.Element("TelescopeStatus",
                                     attrib={"timestamp": str(time.time())})
                for name, sensor in self._sensors.items():
                    child = etree.Element("TelStat")
                    make_elem(child, "Name", name)
                    make_elem(child, "Value", str(sensor.value()))
                    make_elem(child, "Status", str(sensor.status()))
                    make_elem(child, "Type", self._parser[name]["type"])
                    if "units" in self._parser[name]:
                        make_elem(child, "Units", self._parser[name]["units"])
                    root.append(child)
            except Exception as error:
                # A failed conversion must not be reported as success.
                req.reply("fail", str(error))
            else:
                req.reply("ok", etree.tostring(root))

        # The reply is sent asynchronously from the ioloop callback.
        self.ioloop.add_callback(convert)
        raise AsyncReply

    @request()
    @return_reply(Str())
    def request_json(self, req):
        """request an JSON version of the status message"""
        return ("ok", self.as_json())

    def as_json(self):
        """Convert status sensors to JSON object"""
        out = {name: str(sensor.value())
               for name, sensor in self._sensors.items()}
        return json.dumps(out)

    @request(Str())
    @return_reply(Str())
    def request_sensor_control(self, req, name):
        """take control of a given sensor value"""
        if name not in self._sensors:
            return ("fail", "No sensor named '{0}'".format(name))
        self._controlled.add(name)
        return ("ok", "{0} under user control".format(name))

    @request()
    @return_reply(Str())
    def request_sensor_control_all(self, req):
        """take control of all sensors value"""
        self._controlled.update(self._sensors)
        return ("ok",
                "{0} sensors under user control".format(len(self._controlled)))

    @request()
    @return_reply(Int())
    def request_sensor_list_controlled(self, req):
        """List all controlled sensors"""
        count = len(self._controlled)
        for name in list(self._controlled):
            req.inform("{0} -- {1}".format(name, self._sensors[name].value()))
        return ("ok", count)

    @request(Str())
    @return_reply(Str())
    def request_sensor_release(self, req, name):
        """release a sensor from user control"""
        if name not in self._sensors:
            return ("fail", "No sensor named '{0}'".format(name))
        if name not in self._controlled:
            # Avoid the KeyError that set.remove() would raise here.
            return ("fail",
                    "Sensor '{0}' not under user control".format(name))
        self._controlled.remove(name)
        return ("ok", "{0} released from user control".format(name))

    @request()
    @return_reply(Str())
    def request_sensor_release_all(self, req):
        """release all sensors from user control"""
        self._controlled = set()
        return ("ok", "All sensors released")

    @request(Str(), Str())
    @return_reply(Str())
    def request_sensor_set(self, req, name, value):
        """Set the value of a sensor"""
        if name not in self._sensors:
            return ("fail", "No sensor named '{0}'".format(name))
        if name not in self._controlled:
            return ("fail",
                    "Sensor '{0}' not under user control".format(name))
        try:
            param = self._parser[name]
            # Convert the string argument to the sensor's declared type.
            value = TYPE_CONVERTER[param["type"]](value)
            self._sensors[name].set_value(value)
        except Exception as error:
            return ("fail", str(error))
        else:
            return ("ok", "{0} set to {1}".format(name,
                                                  self._sensors[name].value()))

    def setup_sensors(self):
        """Set up basic monitoring sensors.

        One sensor is created per parser entry according to its ``"type"``
        (``float``, ``string``, ``int`` or ``bool``); any other type raises
        an ``Exception``.
        """
        for name, params in self._parser.items():
            if params["type"] == "float":
                sensor = Sensor.float(name,
                                      description=params["description"],
                                      unit=params.get("units", None),
                                      default=params.get("default", 0.0),
                                      initial_status=Sensor.UNKNOWN)
            elif params["type"] == "string":
                sensor = Sensor.string(name,
                                       description=params["description"],
                                       default=params.get("default", ""),
                                       initial_status=Sensor.UNKNOWN)
            elif params["type"] == "int":
                sensor = Sensor.integer(name,
                                        description=params["description"],
                                        default=params.get("default", 0),
                                        unit=params.get("units", None),
                                        initial_status=Sensor.UNKNOWN)
            elif params["type"] == "bool":
                sensor = Sensor.boolean(name,
                                        description=params["description"],
                                        default=params.get("default", False),
                                        initial_status=Sensor.UNKNOWN)
            else:
                raise Exception("Unknown sensor type '{0}' requested".format(
                    params["type"]))
            self.add_sensor(sensor)
class WorkStealing(SchedulerPlugin):
    """Scheduler plugin that moves processing tasks from busy to idle workers.

    Keys that enter the ``'processing'`` state are filed into one of 15
    "levels" according to the ratio of estimated transfer time to compute
    time (see :meth:`steal_time_ratio`). A ``PeriodicCallback`` created with
    ``callback_time=100`` calls :meth:`balance`, which walks the levels and
    moves tasks with :meth:`move_task`.
    """

    def __init__(self, scheduler):
        self.scheduler = scheduler
        # stealable_all[level]: every stealable key at that level
        self.stealable_all = [set() for i in range(15)]
        # stealable[worker][level]: stealable keys processing on that worker
        self.stealable = dict()
        # key -> (worker, level) under which the key is currently filed
        self.key_stealable = dict()
        # key prefix -> keys whose compute time could not be looked up
        self.stealable_unknown_durations = defaultdict(set)
        # Cost multiplier used by balance() for each level
        self.cost_multipliers = [1 + 2**(i - 6) for i in range(15)]
        self.cost_multipliers[0] = 1

        for worker in scheduler.workers:
            self.add_worker(worker=worker)

        self._pc = PeriodicCallback(callback=self.balance,
                                    callback_time=100,
                                    io_loop=self.scheduler.loop)
        self.scheduler.loop.add_callback(self._pc.start)
        self.scheduler.plugins.append(self)
        self.scheduler.extensions['stealing'] = self
        self.scheduler.events['stealing'] = deque(maxlen=100000)
        # Number of balance() runs that moved at least one task
        self.count = 0

    @property
    def log(self):
        """The scheduler's ``'stealing'`` event deque."""
        return self.scheduler.events['stealing']

    def add_worker(self, scheduler=None, worker=None):
        """Create the per-level stealable sets for ``worker``."""
        self.stealable[worker] = [set() for i in range(15)]

    def remove_worker(self, scheduler=None, worker=None):
        """Drop the per-level stealable sets of ``worker``."""
        del self.stealable[worker]

    def teardown(self):
        """Stop the periodic balance callback."""
        self._pc.stop()

    def transition(self,
                   key,
                   start,
                   finish,
                   compute_start=None,
                   compute_stop=None,
                   *args,
                   **kwargs):
        """Keep the stealable sets in sync with a task state transition."""
        if finish == 'processing':
            self.put_key_in_stealable(key)

        if start == 'processing':
            self.remove_key_from_stealable(key)
            if finish == 'memory':
                # A task with this prefix finished; retry the keys that were
                # parked under the same prefix in steal_time_ratio().
                ks = key_split(key)
                if ks in self.stealable_unknown_durations:
                    for k in self.stealable_unknown_durations.pop(ks):
                        if self.scheduler.task_state[k] == 'processing':
                            self.put_key_in_stealable(k, split=ks)

    def put_key_in_stealable(self, key, split=None):
        """File ``key`` under its worker and level, if it may be stolen."""
        worker = self.scheduler.rprocessing[key]
        cost_multiplier, level = self.steal_time_ratio(key, split=split)
        if cost_multiplier is not None:
            self.stealable_all[level].add(key)
            self.stealable[worker][level].add(key)
            self.key_stealable[key] = (worker, level)

    def remove_key_from_stealable(self, key):
        """Remove ``key`` from the stealable sets; missing entries are ignored."""
        result = self.key_stealable.pop(key, None)
        if result is not None:
            worker, level = result
            try:
                self.stealable[worker][level].remove(key)
            except KeyError:
                pass
            try:
                self.stealable_all[level].remove(key)
            except KeyError:
                pass

    def steal_time_ratio(self, key, split=None):
        """ The compute to communication time ratio of a key

        Returns ``(None, None)`` when the key should not be stolen: it has
        restrictions, its prefix is in ``fast_tasks``, its compute time is
        unknown (the key is then remembered in
        ``stealable_unknown_durations``), its compute time is below 0.005,
        or the ratio exceeds 100.

        Returns
        -------

        cost_multiplier: The increased cost from moving this task as a factor.
        For example a result of zero implies a task without dependencies.
        level: The location within a stealable list to place this value
        """
        if (key not in self.scheduler.loose_restrictions and
            (key in self.scheduler.host_restrictions
             or key in self.scheduler.worker_restrictions)
                or key in self.scheduler.resource_restrictions):
            return None, None  # don't steal

        if not self.scheduler.dependencies[key]:  # no dependencies fast path
            return 0, 0

        # Dependencies with no recorded size are counted as 1000.
        nbytes = sum(
            self.scheduler.nbytes.get(k, 1000)
            for k in self.scheduler.dependencies[key])

        transfer_time = nbytes / BANDWIDTH + LATENCY
        split = split or key_split(key)
        if split in fast_tasks:
            return None, None
        try:
            worker = self.scheduler.rprocessing[key]
            compute_time = self.scheduler.processing[worker][key]
        except KeyError:
            self.stealable_unknown_durations[split].add(key)
            return None, None
        else:
            if compute_time < 0.005:  # 5ms, just give up
                return None, None
            cost_multiplier = transfer_time / compute_time
            if cost_multiplier > 100:
                return None, None

            # presumably log_2 is log(2), making this log2(cost) + 6 - confirm
            level = int(round(log(cost_multiplier) / log_2 + 6, 0))
            level = max(1, level)
            return cost_multiplier, level

    def move_task(self, key, victim, thief):
        """Reassign ``key`` from worker ``victim`` to worker ``thief``.

        Updates the scheduler's processing/occupancy bookkeeping, tells the
        victim to release the task and sends the task to the thief. A closed
        victim comm is logged; any other exception is logged and re-raised.
        """
        try:
            if self.scheduler.validate:
                if victim != self.scheduler.rprocessing[key]:
                    import pdb
                    pdb.set_trace()

            self.remove_key_from_stealable(key)
            logger.debug("Moved %s, %s: %2f -> %s: %2f", key, victim,
                         self.scheduler.occupancy[victim], thief,
                         self.scheduler.occupancy[thief])

            # Take the task's duration off the victim ...
            duration = self.scheduler.processing[victim].pop(key)
            self.scheduler.occupancy[victim] -= duration
            self.scheduler.total_occupancy -= duration

            # ... and re-estimate it for the thief, adding transfer time for
            # the dependencies the thief does not hold.
            duration = self.scheduler.task_duration.get(key_split(key), 0.5)
            duration += sum(self.scheduler.nbytes[key]
                            for key in self.scheduler.dependencies[key] -
                            self.scheduler.has_what[thief]) / BANDWIDTH
            self.scheduler.processing[thief][key] = duration
            self.scheduler.rprocessing[key] = thief
            self.scheduler.occupancy[thief] += duration
            self.scheduler.total_occupancy += duration

            self.put_key_in_stealable(key)

            self.scheduler.worker_comms[victim].send({
                'op': 'release-task',
                'reason': 'stolen',
                'key': key
            })

            try:
                self.scheduler.send_task_to_worker(thief, key)
            except CommClosedError:
                self.scheduler.remove_worker(thief)
        except CommClosedError:
            logger.info("Worker comm closed while stealing: %s", victim)
        except Exception as e:
            logger.exception(e)
            if LOG_PDB:
                import pdb
                pdb.set_trace()
            raise

    def balance(self):
        """Move stealable tasks from saturated workers to idle workers.

        Does nothing when there are no idle workers or when every worker is
        idle. Levels are visited in increasing order; a task is moved when
        the idle worker's occupancy plus ``cost_multiplier * duration`` does
        not exceed the saturated worker's occupancy minus half the duration.
        """
        with log_errors():
            i = 0
            s = self.scheduler
            occupancy = s.occupancy
            idle = s.idle
            saturated = s.saturated
            if not idle or len(idle) == len(self.scheduler.workers):
                return

            log = list()
            start = time()

            seen = False
            # NOTE(review): ``acted`` is never set to True in this method, so
            # the loop below stops after the first level where ``seen`` is set.
            acted = False

            if not s.saturated:
                # No worker is flagged saturated: use the 10 workers returned
                # by topk on occupancy that are above the thresholds below.
                saturated = topk(10, s.workers, key=occupancy.get)
                saturated = [
                    w for w in saturated if occupancy[w] > 0.2
                    and len(s.processing[w]) > s.ncores[w]
                ]
            elif len(s.saturated) < 20:
                saturated = sorted(saturated, key=occupancy.get, reverse=True)
            if len(idle) < 20:
                idle = sorted(idle, key=occupancy.get)

            for level, cost_multiplier in enumerate(self.cost_multipliers):
                if not idle:
                    break
                for sat in list(saturated):
                    stealable = self.stealable[sat][level]
                    if not stealable or not idle:
                        continue
                    else:
                        seen = True

                    for key in list(stealable):
                        i += 1
                        if not idle:
                            break
                        # Rotate through the idle workers.
                        idl = idle[i % len(idle)]
                        duration = s.processing[sat][key]

                        if (occupancy[idl] + cost_multiplier * duration <=
                                occupancy[sat] - duration / 2):
                            self.move_task(key, sat, idl)
                            log.append((start, level, key, duration, sat,
                                        occupancy[sat], idl, occupancy[idl]))
                            self.scheduler.check_idle_saturated(sat)
                            self.scheduler.check_idle_saturated(idl)
                            seen = True

                if self.cost_multipliers[
                        level] < 20:  # don't steal from public at cost
                    stealable = self.stealable_all[level]
                    if stealable:
                        seen = True
                    for key in list(stealable):
                        if not idle:
                            break

                        sat = s.rprocessing[key]
                        if occupancy[sat] < 0.2:
                            continue
                        if len(s.processing[sat]) <= s.ncores[sat]:
                            continue

                        i += 1
                        idl = idle[i % len(idle)]
                        duration = s.processing[sat][key]
                        if (occupancy[idl] + cost_multiplier * duration <=
                                occupancy[sat] - duration / 2):
                            self.move_task(key, sat, idl)
                            log.append((start, level, key, duration, sat,
                                        occupancy[sat], idl, occupancy[idl]))
                            self.scheduler.check_idle_saturated(sat)
                            self.scheduler.check_idle_saturated(idl)
                            seen = True

                if seen and not acted:
                    break

            if log:
                self.log.append(log)
                self.count += 1
            stop = time()
            if self.scheduler.digests:
                self.scheduler.digests['steal-duration'].add(stop - start)

    def restart(self, scheduler):
        """Clear all stealable bookkeeping."""
        for stealable in self.stealable.values():
            for s in stealable:
                s.clear()

        for s in self.stealable_all:
            s.clear()
        self.key_stealable.clear()
        self.stealable_unknown_durations.clear()

    def story(self, *keys):
        """Return the logged steal records that mention any of ``keys``."""
        keys = set(keys)
        return [t for L in self.log for t in L if any(x in keys for x in t)]
class CrawlProcessHandlerBase(object):
    """Common state and chat/spectator logic for a running crawl game.

    The class tracks the websocket receivers attached to one game (the
    player plus any watchers), relays chat and notifications to them,
    maintains the player's mute list, and keeps the "where" information
    that is shown in the lobby.  Starting the actual process and handling
    input is left to subclasses (see ``handle_input``).

    Several attributes read here (``exit_reason``, ``exit_message``,
    ``exit_dump_url``) are not assigned in this class and are presumably
    provided by subclasses before ``handle_process_end`` runs.
    """

    def __init__(self, game_params, username, logger):
        """Initialise bookkeeping for one game and start the idle checker.

        :param game_params: mapping with the game's configuration (paths,
            binary, options, ``id``).
        :param username: name of the player that owns the game.
        :param logger: logger (or ``LoggerAdapter``) to wrap; messages are
            prefixed with this handler's id.
        """
        global last_game_id

        self.game_params = game_params
        self.username = username
        self.logger = logging.LoggerAdapter(logger, {})

        try:
            # Probe for the attribute: it is missing on chained adapters
            # under the python bug described below.
            self.logger.manager
            self.logger.process = self._process_log_msg
        except AttributeError:
            # This is a workaround for a python 3.5 bug with chained
            # LoggerAdapters, where delegation is not handled properly (e.g.
            # manager isn't set, _log isn't available, etc.). This simple fix
            # only handles two levels of chaining. The general fix is to
            # upgrade to python 3.7.
            # Issue: https://bugs.python.org/issue31457
            self.logger = logging.LoggerAdapter(logger.logger, {})
            self.logger.process = lambda m, k: logger.process(
                *self._process_log_msg(m, k))

        self.queue_messages = False
        self.process = None
        self.client_path = self.config_path("client_path")
        self.crawl_version = None
        self.where = {}
        self.wheretime = 0
        self.last_milestone = None
        self.kill_timeout = None
        self.muted = set()

        now = datetime.datetime.utcnow()
        self.formatted_time = now.strftime("%Y-%m-%d.%H:%M:%S")
        self.lock_basename = self.formatted_time + ".ttyrec"

        self.end_callback = None
        self._receivers = set()
        self.last_activity_time = time.time()
        self.idle_checker = PeriodicCallback(self.check_idle, 10000)
        self.idle_checker.start()
        self._was_idle = False
        self.last_watcher_join = 0
        self.receiving_direct_milestones = False

        # Hand out a process-wide, monotonically increasing game id.
        self.id = last_game_id + 1
        last_game_id = self.id

    def _process_log_msg(self, msg, kwargs):
        """Prefix a log message with this handler's id (``P<id>``)."""
        return "P%-5s %s" % (self.id, msg), kwargs

    def format_path(self, path):
        """Expand a configured path template for this user and game."""
        return dgl_format_str(path, self.username, self.game_params)

    def config_path(self, key):
        """Return the formatted path stored under ``key``, or None.

        When ``key`` is ``"socket_path"`` and the ``live_debug`` config
        option is set, a ``live-debug`` sub-directory is used instead and
        created on demand.
        """
        if key not in self.game_params:
            return None
        base_path = self.format_path(self.game_params[key])
        if key == "socket_path" and config.get('live_debug'):
            # TODO: this is kind of brute-force given that regular paths
            # aren't validated at all...
            debug_path = os.path.join(base_path, 'live-debug')
            if not os.path.isdir(debug_path):
                os.makedirs(debug_path)
            return debug_path
        else:
            return base_path

    def idle_time(self):
        """Return whole seconds elapsed since the last recorded activity."""
        return int(time.time() - self.last_activity_time)

    def is_idle(self):
        """Return True once more than 30 seconds passed without activity."""
        return self.idle_time() > 30

    def check_idle(self):
        """Refresh the lobbies when the idle state has flipped."""
        if self.is_idle() != self._was_idle:
            self._was_idle = self.is_idle()
            if config.get('dgl_mode'):
                update_all_lobbys(self)

    def flush_messages_to_all(self):
        """Flush queued messages on every receiver."""
        for receiver in self._receivers:
            receiver.flush_messages()

    def write_to_all(self, msg, send):
        # type: (str, bool) -> None
        """Append ``msg`` to every receiver's outgoing buffer."""
        for receiver in self._receivers:
            receiver.append_message(msg, send)

    def send_to_all(self, msg, **data):
        # type: (str, Any) -> None
        """Send message ``msg`` with payload ``data`` to every receiver."""
        for receiver in self._receivers:
            receiver.send_message(msg, **data)

    def chat_help_message(self, source, command, desc):
        """Send one line of chat help to ``source``.

        An empty ``command`` produces a continuation line that only carries
        the (escaped) description.
        """
        if len(command) == 0:
            self.handle_notification_raw(
                source,
                " " * 8 + "<span>%s</span>" % (xhtml_escape(desc)))
        else:
            self.handle_notification_raw(
                source,
                " " * 4 + "<span>%s: %s</span>" % (xhtml_escape(command),
                                                   xhtml_escape(desc)))

    def chat_command_help(self, source):
        """Send the list of available chat commands to ``source``.

        Mute related commands are only listed for the player.
        """
        # TODO: generalize
        # the chat window is basically fixed width, and these are calibrated
        # to not do a linewrap
        self.handle_notification(
            source, "The following chat commands are available:")
        self.chat_help_message(source, "/help", "show chat command help.")
        self.chat_help_message(source, "/hide", "hide the chat window.")
        if self.is_player(source):
            self.chat_help_message(source, "/mute <name>",
                                   "add <name> to the mute list.")
            self.chat_help_message(source, "", "Must be present in channel.")
            self.chat_help_message(source, "/mutelist",
                                   "show your entire mute list.")
            self.chat_help_message(source, "/unmute <name>",
                                   "remove <name> from the mute list.")
            self.chat_help_message(source, "/unmute *",
                                   "clear your mute list.")

    def handle_chat_command(self, source_ws, text):
        # type: (CrawlWebSocket, str) -> bool
        """Dispatch a ``/command`` typed into chat.

        :returns: True when ``text`` was recognised as a chat command,
            False when it is ordinary chat text or an unknown command.
        """
        source = source_ws.username
        text = text.strip()
        if len(text) == 0 or text[0] != '/':
            return False
        splitlist = text.split(None, 1)
        if len(splitlist) == 1:
            command = splitlist[0]
            remainder = ""
        else:
            command, remainder = splitlist
        command = command.lower()
        # TODO: generalize
        if command == "/mute":
            self.mute(source, remainder)
        elif command == "/unmute":
            self.unmute(source, remainder)
        elif command == "/mutelist":
            self.show_mute_list(source)
        elif command == "/help":
            self.chat_command_help(source)
        elif command == "/hide":
            self.hide_chat(source_ws, remainder.strip())
        else:
            return False
        return True

    def handle_chat_message(self, username, text):
        # type: (str, str) -> None
        """Broadcast a chat line unless ``username`` is muted.

        Only ``text`` is escaped here; ``username`` is inserted as-is.
        """
        if username in self.muted:
            # TODO: message?
            return
        chat_msg = ("<span class='chat_sender'>%s</span>: <span class='chat_msg'>%s</span>" %
                    (username, xhtml_escape(text)))
        self.send_to_all("chat", content=chat_msg)

    def get_receivers_by_username(self, username):
        """Return every non-anonymous receiver logged in as ``username``."""
        result = list()
        for w in self._receivers:
            if not w.username:
                continue
            if w.username == username:
                result.append(w)
        return result

    def get_primary_receiver(self):
        """Return the player's own (non-watching) receiver, or None."""
        # TODO: does this work with console? Probably not...
        if self.username is None:
            return None
        receivers = self.get_receivers_by_username(self.username)
        for r in receivers:
            if not r.watched_game:
                return r
        return None

    def send_to_user(self, username, msg, **data):
        # type: (str, str, Any) -> None
        """Send ``msg`` to every receiver belonging to ``username``."""
        # a single user may be viewing from multiple receivers
        for receiver in self.get_receivers_by_username(username):
            receiver.send_message(msg, **data)

    # obviously, don't use this for player/spectator-accessible data. But, it
    # is still partially sanitized in chat.js.
    def handle_notification_raw(self, username, text):
        # type: (str, str) -> None
        """Send ``text`` to ``username`` as a chat notification, unescaped."""
        msg = ("<span class='chat_msg'>%s</span>" % text)
        self.send_to_user(username, "chat", content=msg, meta=True)

    def handle_notification(self, username, text):
        # type: (str, str) -> None
        """Send ``text`` to ``username`` as an escaped chat notification."""
        self.handle_notification_raw(username, xhtml_escape(text))

    def handle_process_end(self):
        """Clean up after the game process ended and notify watchers."""
        if self.kill_timeout:
            IOLoop.current().remove_timeout(self.kill_timeout)
            self.kill_timeout = None

        self.idle_checker.stop()

        # send game_ended message to watchers. The player is handled in cleanup
        # code in ws_handler.py.
        # Iterate over a copy: the calls below may change the receiver set.
        for watcher in list(self._receivers):
            if watcher.watched_game == self:
                watcher.send_message("game_ended",
                                     reason=self.exit_reason,
                                     message=self.exit_message,
                                     dump=self.exit_dump_url)
                watcher.go_lobby()

        if self.end_callback:
            self.end_callback()

    def get_watchers(self, chatting_only=False):
        """Return ``(player_name, watcher_names)`` for logged-in receivers.

        Anonymous receivers are skipped.  With ``chatting_only`` receivers
        that hid the chat are skipped as well.  Watcher names are sorted
        case-insensitively; ``player_name`` is None when no receiver is
        playing.
        """
        # TODO: I don't understand why this code didn't just use self.username,
        # when will this be different than player_name? Maybe for a console
        # player?
        player_name = None
        watchers = list()
        for w in self._receivers:
            if not w.username:  # anon
                continue
            if chatting_only and w.chat_hidden:
                continue
            if not w.watched_game:
                player_name = w.username
            else:
                watchers.append(w.username)
        watchers.sort(key=lambda s: s.lower())
        return (player_name, watchers)

    def is_player(self, username):
        """Return True when ``username`` is the player of this game."""
        # TODO: probably doesn't work for console players spectating themselves
        # TODO: let admin accounts mute as well?
        player_name, watchers = self.get_watchers()
        return (username == player_name)

    def hide_chat(self, receiver, param):
        """Toggle the chat window, or hide it permanently for "forever"."""
        if param == "forever":
            receiver.send_message("super_hide_chat")
            receiver.chat_hidden = True  # currently only for super hidden chat
            self.update_watcher_description()
        else:
            receiver.send_message("toggle_chat")

    def restore_mutelist(self, source, l):
        """Replace the mute list with ``l`` on behalf of the player.

        Does nothing when ``source`` is not the player or ``l`` is None or
        empty.  The player's own name is dropped from the restored list.
        """
        if not self.is_player(source) or l is None:
            return
        if len(l) == 0:
            return
        self.muted = {u for u in l if u != source}
        self.handle_notification(source, "Restoring mute list.")
        self.show_mute_list(source)
        self.logger.info("Player '%s' restoring mutelist %s" %
                         (source, repr(list(self.muted))))

    def save_mutelist(self, source):
        """Persist the mute list through the player's primary receiver."""
        if not self.is_player(source):
            return
        receiver = self.get_primary_receiver()
        if receiver is not None:
            receiver.save_mutelist(list(self.muted))

    def mute(self, source, target):
        """Mute spectator ``target``; only the player may do this.

        :returns: True when ``target`` was muted, False otherwise (a
            notification explaining why is sent to ``source``).
        """
        if not self.is_player(source):
            self.handle_notification(
                source, "You do not have permission to mute spectators.")
            return False
        if (source == target):
            self.handle_notification(source, "You can't mute yourself!")
            return False
        player_name, watchers = self.get_watchers()
        watchers = set(watchers)
        if not target in watchers:
            self.handle_notification(source, "Mute who??")
            return False
        self.logger.info("Player '%s' has muted '%s'" % (source, target))
        self.handle_notification(
            source, "Spectator '%s' has now been muted." % target)
        self.muted |= {target}
        self.save_mutelist(source)
        self.update_watcher_description()
        return True

    def unmute(self, source, target):
        """Unmute ``target``, or clear the whole list when target is "*".

        :returns: True when the mute list changed, False otherwise.
        """
        if not self.is_player(source):
            self.handle_notification(
                source, "You do not have permission to unmute spectators.")
            return False
        if (source == target):
            self.handle_notification(
                source, "You can't unmute (or mute) yourself!")
            return False
        if target == "*":
            if (len(self.muted) == 0):
                self.handle_notification(source, "No one is muted!")
                return False
            self.logger.info("Player '%s' has cleared their mute list." %
                             (source))
            self.handle_notification(source,
                                     "You have cleared your mute list.")
            self.muted = set()
            self.save_mutelist(source)
            self.update_watcher_description()
            return True

        if not target in self.muted:
            self.handle_notification(source, "Unmute who??")
            return False

        self.logger.info("Player '%s' has unmuted '%s'" % (source, target))
        self.handle_notification(source, "You have unmuted '%s'." % target)
        self.muted -= {target}
        self.save_mutelist(source)
        self.update_watcher_description()
        return True

    def show_mute_list(self, source):
        """Tell the player who is muted; returns False for non-players."""
        if not self.is_player(source):
            return False
        names = list(self.muted)
        names.sort(key=lambda s: s.lower())
        if len(names) == 0:
            self.handle_notification(source, "No one is muted.")
        else:
            self.handle_notification(
                source, "You have muted: " + ", ".join(names))
        return True

    def get_anon(self):
        """Return the receivers that are not logged in."""
        return [w for w in self._receivers if not w.username]

    def update_watcher_description(self):
        """Rebuild the spectator list and send it to every receiver."""
        player_url_template = config.get('player_url')

        def wrap_name(watcher, is_player=False):
            # Render one name, tagging muted users and the player.
            if is_player:
                class_type = 'player'
            else:
                class_type = 'watcher'
            n = watcher
            if (watcher in self.muted):
                n += " (muted)"
            if player_url_template is None:
                return "<span class='{0}'>{1}</span>".format(class_type, n)
            player_url = player_url_template.replace("%s", watcher.lower())
            # NOTE(review): this literal has no replacement fields, so the
            # format() arguments are ignored and "******" is returned
            # verbatim.  It looks like the original markup was lost or
            # redacted; confirm against upstream before relying on it.
            username = "******".format(player_url, class_type, n)
            return username

        player_name, watchers = self.get_watchers(True)
        watcher_names = []
        if player_name is not None:
            watcher_names.append(wrap_name(player_name, True))
        watcher_names += [wrap_name(w) for w in watchers]

        anon_count = len(self.get_anon())
        s = ", ".join(watcher_names)
        if len(watcher_names) > 0 and anon_count > 0:
            s = s + " and %i Anon" % anon_count
        elif anon_count > 0:
            s = "%i Anon" % anon_count

        self.send_to_all("update_spectators",
                         count=self.watcher_count(),
                         names=s)

        if config.get('dgl_mode'):
            update_all_lobbys(self)

    def add_watcher(self, watcher):
        """Attach ``watcher`` to this game and refresh the spectator list."""
        self.last_watcher_join = time.time()
        if self.client_path:
            self._send_client(watcher)
            if watcher.watched_game == self:
                watcher.send_json_options(self.game_params["id"],
                                          self.username)
        self._receivers.add(watcher)
        self.update_watcher_description()

    def remove_watcher(self, watcher):
        """Detach ``watcher``; raises KeyError if it was not attached."""
        self._receivers.remove(watcher)
        self.update_watcher_description()

    def watcher_count(self):
        """Count receivers that are watching and have not hidden chat."""
        return sum(1 for w in self._receivers
                   if w.watched_game and not w.chat_hidden)

    def send_client_to_all(self):
        """Send the game client (and options to watchers) to everyone."""
        for receiver in self._receivers:
            self._send_client(receiver)
            if receiver.watched_game == self:
                receiver.send_json_options(self.game_params["id"],
                                           self.username)

    def _send_client(self, watcher):
        """Render ``game.html`` and send it to ``watcher``.

        The version token is a SHA-1 of the absolute client path plus the
        crawl version (when known); it is registered with GameDataHandler
        together with the client's ``static`` directory.
        """
        h = hashlib.sha1(utf8(os.path.abspath(self.client_path)))
        if self.crawl_version:
            h.update(utf8(self.crawl_version))
        v = h.hexdigest()
        GameDataHandler.add_version(v,
                                    os.path.join(self.client_path, "static"))

        templ_path = os.path.join(self.client_path, "templates")
        loader = DynamicTemplateLoader.get(templ_path)
        templ = loader.load("game.html")
        game_html = to_unicode(templ.generate(version=v))
        watcher.send_message("game_client", version=v, content=game_html)

    def stop(self):
        """Send SIGHUP to the process and schedule ``kill`` as a fallback."""
        if self.process:
            self.process.send_signal(subprocess.signal.SIGHUP)
            t = time.time() + config.get('kill_timeout')
            self.kill_timeout = IOLoop.current().add_timeout(t, self.kill)

    def kill(self):
        """Send SIGABRT to a process that ignored the earlier SIGHUP."""
        if self.process:
            self.logger.info("Killing crawl process after SIGHUP did nothing.")
            self.process.send_signal(subprocess.signal.SIGABRT)
            self.kill_timeout = None

    # Keys of the "where" data whose change triggers a lobby refresh and
    # which are copied into the lobby entry.
    interesting_info = ("xl", "char", "place", "turn", "dur", "god", "title")

    def set_where_info(self, newwhere):
        """Store new "where" data and refresh lobbies when it matters.

        A refresh happens when the data carries a milestone, the status is
        ``"chargen"``, or any key in ``interesting_info`` changed.
        """
        # milestone doesn't count as "interesting" but the field is directly
        # handled when sending lobby info by looking at last_milestone
        milestone = bool(newwhere.get("milestone"))
        interesting = (milestone
                       or newwhere.get("status") == "chargen"
                       or any(self.where.get(key) != newwhere.get(key)
                              for key in
                              CrawlProcessHandlerBase.interesting_info))

        # ignore milestone sync messages for where purposes
        if newwhere.get("status") != "milestone_only":
            self.where = newwhere
        if milestone:
            self.last_milestone = newwhere
        if interesting:
            update_all_lobbys(self)

    def check_where(self):
        """Re-read the player's ``.where`` file if it changed on disk.

        Skipped entirely while milestones are received directly.  Missing
        or unreadable files are ignored; parse failures are logged.
        """
        if self.receiving_direct_milestones:
            return
        morgue_path = self.config_path("morgue_path")
        wherefile = os.path.join(morgue_path, self.username + ".where")
        try:
            if os.path.getmtime(wherefile) > self.wheretime:
                self.wheretime = time.time()
                with open(wherefile, "r") as f:
                    wheredata = f.read()

                if wheredata.strip() == "":
                    return

                try:
                    newwhere = parse_where_data(wheredata)
                except Exception:
                    # Deliberately broad (the parser's failure modes are not
                    # visible here), but no longer a bare except, so
                    # KeyboardInterrupt/SystemExit still propagate.
                    self.logger.warning(
                        "Exception while trying to parse where file!",
                        exc_info=True)
                else:
                    if (newwhere.get("status") == "active"
                            or newwhere.get("status") == "saved"):
                        self.set_where_info(newwhere)
        except (OSError, IOError):
            pass

    def lobby_entry(self):
        """Build the dict describing this game in the lobby listing."""
        entry = {
            "id": self.id,
            "username": self.username,
            "spectator_count": self.watcher_count(),
            "idle_time": (self.idle_time() if self.is_idle() else 0),
            "game_id": self.game_params["id"],
        }
        for key in CrawlProcessHandlerBase.interesting_info:
            if key in self.where:
                entry[key] = self.where[key]
        if self.last_milestone and self.last_milestone.get("milestone"):
            entry["milestone"] = self.last_milestone.get("milestone")
        return entry

    def human_readable_where(self):
        """Return ``"L<xl> <char>, <place>"`` or "" if data is missing."""
        try:
            return "L{xl} {char}, {place}".format(**self.where)
        except KeyError:
            return ""

    def _base_call(self):
        """Assemble the base command line used to launch the game binary."""
        game = self.game_params

        call = [game["crawl_binary"]]

        if "pre_options" in game:
            call += game["pre_options"]

        call += ["-name", self.username,
                 "-rc", os.path.join(self.config_path("rcfile_path"),
                                     self.username + ".rc"),
                 "-macro", os.path.join(self.config_path("macro_path"),
                                        self.username + ".macro"),
                 "-morgue", self.config_path("morgue_path")]

        if "options" in game:
            call += game["options"]

        if "dir_path" in game:
            call += ["-dir", self.config_path("dir_path")]

        return call

    def note_activity(self):
        """Record activity now and re-evaluate the idle state."""
        self.last_activity_time = time.time()
        self.check_idle()

    def handle_input(self, msg):
        """Handle input from the client; must be implemented by subclasses."""
        raise NotImplementedError()
class PollingHandler(BaseSocketHandler):
    """
    This class represents a separate websocket connection.

    Attributes:
        tracker: PeriodicCallback with the `get_location` method as its
            callback. Started by the "track" command and stopped by
            "stop"/"reset"; it is created with an interval of 5000
            (every 5 seconds) to find out and update the character's
            location.
        q: Queue (maxsize 5) used for running tasks successively.
        updating: A flag that indicates whether the router is being
            updated. Required to avoid race conditions with backups.
    """

    def __init__(self, *args, **kwargs):
        super(PollingHandler, self).__init__(*args, **kwargs)
        # Set up the PeriodicCallback with our `self.get_location`; it is
        # launched later on by the track/untrack commands.
        self.tracker = PeriodicCallback(self.get_location, 5000)
        self.q = Queue(maxsize=5)
        self.updating = False

    async def get_location(self):
        """
        The callback for the `self.tracker`. Makes an API call,
        updates the router and sends updated data to the front-end.
        """
        # Call API to find out the current character location
        location = await self.character(self.user_id, '/location/', 'GET')

        if location:
            # Set the `updating` flag to not accept periodic updates
            # from the front-end, to not overwrite new data
            self.updating = True
            try:
                user = self.user
                graph_data = await user['router'].update(
                    location['solarSystem']['name'])
                if graph_data:
                    message = ['update', graph_data]
                    logging.warning(graph_data)
                    await self.safe_write(message)
            finally:
                # Always release the flag: if the update or the write
                # raised, a stuck `True` would silently disable every
                # later 'backup' task for this connection.
                self.updating = False
        else:
            message = ['warning', 'Log into game to track your route']
            await self.safe_write(message)

    async def scheduler(self):
        """
        Scheduler for user tasks. Waits until there is a new item in the
        queue, does the task, resolves the task.
        Tornado queues doc: http://www.tornadoweb.org/en/stable/queues.html

        Since we have no guarantee of the order of the incoming messages
        (a new message from the front-end can come before the current one
        is done), we need to ensure all tasks run successively.
        Here comes the asynchronous generator.

        Items come from decoded front-end messages, so their shape is not
        trusted: anything that is not a known command is ignored.
        """
        logging.info(f"Scheduler started for {self.request.remote_ip}")

        # Wait on each iteration until there's actually an item available
        async for item in self.q:
            logging.debug(f"Started resolving task for {item}...")
            user = self.user
            try:
                if item == 'recover':
                    # Send saved route
                    await self.safe_write(['recover', user['router'].recovery])
                elif item == 'track':
                    # Start the PeriodicCallback
                    if not self.tracker.is_running():
                        self.tracker.start()
                elif item in ['stop', 'reset']:
                    # Stop the PeriodicCallback
                    if self.tracker.is_running():
                        self.tracker.stop()
                    # Clear all saved data
                    if item == 'reset':
                        await user['router'].reset()
                elif (isinstance(item, (list, tuple))
                        and len(item) >= 2
                        and item[0] == 'backup'):
                    # Do not overwrite user object while it's updating,
                    # just in case, to avoid race conditions.
                    if not self.updating:
                        await user['router'].backup(item[1])
                else:
                    # A bare `item[0]` on a dict, number or short list
                    # used to raise here and kill the scheduler loop.
                    logging.debug(f"Ignoring unknown task {item!r}")
            finally:
                self.q.task_done()
                logging.debug(f'Task "{item}" done.')

    async def task(self, item):
        """
        Intermediary between `self.on_message` and `self.scheduler`.
        Since we cannot do anything asynchronous in `self.on_message`,
        this method can handle any additional non-blocking stuff if we
        need it.

        :argument item: item to pass to the `self.scheduler`.
        """
        await self.q.put(item)
        #await self.q.join()

    def open(self):
        """
        Triggers on successful websocket connection.

        Ensures the user is authorized, spawns `self.scheduler` for user
        tasks, adds this websocket object to the connections pool and
        spawns the recovery of the saved route. Unauthorized connections
        are closed right away.
        """
        logging.info(f"Connection received from {self.request.remote_ip}")
        if self.user_id:
            self.spawn(self.scheduler)
            self.vagrants.append(self)
            self.spawn(self.task, 'recover')
        else:
            self.close()

    def on_message(self, message):
        """
        Triggers on receiving a front-end message.

        :argument message: front-end message.

        Receives user commands and passes them to the `self.scheduler`
        via `self.task`.
        """
        self.spawn(self.task, json_decode(message))

    def on_close(self):
        """
        Triggers on closed websocket connection.

        Removes this websocket object from the connections pool (if it
        was ever added) and stops `self.tracker` if it is running.
        """
        # Connections rejected in `open` were never appended; an
        # unconditional remove() would raise ValueError for them and skip
        # the rest of the cleanup.
        if self in self.vagrants:
            self.vagrants.remove(self)
        if self.tracker.is_running():
            self.tracker.stop()
        logging.info("Connection closed, " + self.request.remote_ip)
class PeriodicCallback(param.Parameterized):
    """
    Periodic encapsulates a periodic callback which will run both
    in tornado based notebook environments and on bokeh server. By
    default the callback will run until the stop method is called,
    but count and timeout values can be set to limit the number of
    executions or the maximum length of time for which the callback
    will run.
    """

    callback = param.Callable(doc="""
        The callback to execute periodically.""")

    count = param.Integer(default=None, doc="""
        Number of times the callback will be executed, by default
        this is unlimited.""")

    period = param.Integer(default=500, doc="""
        Period in milliseconds at which the callback is executed.""")

    timeout = param.Integer(default=None, doc="""
        Timeout in seconds from the start time at which the callback
        expires""")

    def __init__(self, **params):
        super(PeriodicCallback, self).__init__(**params)
        self._counter = 0       # executions since the last start()
        self._start_time = None
        self._timeout = None    # active timeout (seconds) while running
        self._cb = None         # handle of the scheduled callback
        self._doc = None        # bokeh document the callback lives on

    def start(self):
        """Start executing the callback every ``period`` milliseconds.

        Inside a Bokeh server session the callback is registered on the
        current document, otherwise a tornado PeriodicCallback is used.

        Raises:
            RuntimeError: if the callback has already been started.
        """
        if self._cb is not None:
            raise RuntimeError('Periodic callback has already started.')
        self._start_time = time.time()
        # Arm the timeout. Previously ``_timeout`` was never assigned from
        # the ``timeout`` parameter, so the timeout could never trigger.
        self._timeout = self.timeout
        if _curdoc().session_context:
            self._doc = _curdoc()
            self._cb = self._doc.add_periodic_callback(
                self._periodic_callback, self.period)
        else:
            # Local import: the tornado class shares this class's name.
            from tornado.ioloop import PeriodicCallback
            self._doc = None
            self._cb = PeriodicCallback(self._periodic_callback, self.period)
            self._cb.start()

    def _periodic_callback(self):
        """Run the user callback once and stop when a limit is reached."""
        self.callback()
        self._counter += 1
        if self._timeout is not None:
            dt = (time.time() - self._start_time)
            if dt > self._timeout:
                self.stop()
                return
        if self._counter == self.count:
            self.stop()

    def stop(self):
        """Stop the callback and reset the execution counter.

        Calling ``stop`` when nothing is scheduled is a no-op, so it is
        safe for the user callback itself to call it.
        """
        self._counter = 0
        self._timeout = None
        if self._cb is None:
            return
        if self._doc:
            self._doc.remove_periodic_callback(self._cb)
        else:
            self._cb.stop()
        # Forget the document so a later start() outside a server session
        # does not try to remove its callback from a stale document.
        self._doc = None
        self._cb = None
class BokehTornado(TornadoApplication):
    ''' A Tornado Application used to implement the Bokeh Server.

    Args:
        applications (dict[str,Application] or Application) :
            A map from paths to ``Application`` instances.

            If the value is a single Application, then the following mapping
            is generated:

            .. code-block:: python

                applications = {{ '/' : applications }}

            When a connection comes in to a given path, the associate
            Application is used to generate a new document for the session.

        prefix (str, optional) :
            A URL prefix to use for all Bokeh server paths. (default: None)

        extra_websocket_origins (list[str], optional) :
            A list of hosts that can connect to the websocket.

            This is typically required when embedding a Bokeh server app in an
            external web site using :func:`~bokeh.embed.server_document` or
            similar.

            If None, this class stores an empty set of origins.
            (default: None)

        extra_patterns (seq[tuple], optional) :
            A list of tuples of (str, http or websocket handler)

            Use this argument to add additional endpoints to custom deployments
            of the Bokeh Server. The sequence is copied, not modified.

            If None, then ``[]`` will be used. (default: None)

        secret_key (str, optional) :
            A secret key for signing session IDs.

            Defaults to the current value of the environment variable
            ``BOKEH_SECRET_KEY``

        sign_sessions (bool, optional) :
            Whether to cryptographically sign session IDs

            Defaults to the current value of the environment variable
            ``BOKEH_SIGN_SESSIONS``. If ``True``, then ``secret_key`` must
            also be provided (either via environment setting or passed as
            a parameter value)

        generate_session_ids (bool, optional) :
            Whether to generate a session ID if one is not provided
            (default: True)

        keep_alive_milliseconds (int, optional) :
            Number of milliseconds between keep-alive pings
            (default: {DEFAULT_KEEP_ALIVE_MS})

            Pings normally required to keep the websocket open. Set to 0 to
            disable pings.

        check_unused_sessions_milliseconds (int, optional) :
            Number of milliseconds between checking for unused sessions
            (default: {DEFAULT_CHECK_UNUSED_MS})

        unused_session_lifetime_milliseconds (int, optional) :
            Number of milliseconds for unused session lifetime
            (default: {DEFAULT_UNUSED_LIFETIME_MS})

        stats_log_frequency_milliseconds (int, optional) :
            Number of milliseconds between logging stats
            (default: {DEFAULT_STATS_LOG_FREQ_MS})

        mem_log_frequency_milliseconds (int, optional) :
            Number of milliseconds between logging memory information
            (default: {DEFAULT_MEM_LOG_FREQ_MS})

            Enabling this feature requires the optional dependency ``psutil``
            to be installed. Set to 0 to disable.

        use_index (bool, optional) :
            Whether to generate an index of running apps in the
            ``RootHandler`` (default: True)

        index (str, optional) :
            Path to a Jinja2 template to serve as the index for "/" if
            use_index is True. If None, the basic built in app index template
            is used. (default: None)

        redirect_root (bool, optional) :
            When there is only a single running application, whether to
            redirect requests to ``"/"`` to that application automatically
            (default: True)

            If there are multiple Bokeh applications configured, this option
            has no effect.

        websocket_max_message_size_bytes (int, optional):
            Set the Tornado ``websocket_max_message_size`` value.
            (default: {DEFAULT_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES})

        auth_provider (AuthProvider, optional):
            An AuthProvider instance

        xsrf_cookies (bool, optional):
            Passed to Tornado as the ``xsrf_cookies`` application setting.
            (default: False)

    Any additional keyword arguments are passed to ``tornado.web.Application``.

    Raises:
        ValueError : if any of the millisecond or size arguments is out of
            its documented range.
        RuntimeError : if no websocket route is found for an application.

    '''

    def __init__(self,
                 applications,
                 prefix=None,
                 extra_websocket_origins=None,
                 extra_patterns=None,
                 secret_key=settings.secret_key_bytes(),
                 sign_sessions=settings.sign_sessions(),
                 generate_session_ids=True,
                 keep_alive_milliseconds=DEFAULT_KEEP_ALIVE_MS,
                 check_unused_sessions_milliseconds=DEFAULT_CHECK_UNUSED_MS,
                 unused_session_lifetime_milliseconds=DEFAULT_UNUSED_LIFETIME_MS,
                 stats_log_frequency_milliseconds=DEFAULT_STATS_LOG_FREQ_MS,
                 mem_log_frequency_milliseconds=DEFAULT_MEM_LOG_FREQ_MS,
                 use_index=True,
                 redirect_root=True,
                 websocket_max_message_size_bytes=DEFAULT_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES,
                 index=None,
                 auth_provider=NullAuth(),
                 xsrf_cookies=False,
                 **kwargs):

        # This will be set when initialize is called
        self._loop = None

        if isinstance(applications, Application):
            applications = { '/' : applications }

        # Normalize the prefix to either "" or "/something" (no trailing /).
        if prefix is None:
            prefix = ""
        prefix = prefix.strip("/")
        if prefix:
            prefix = "/" + prefix

        self._prefix = prefix

        self._index = index

        if keep_alive_milliseconds < 0:
            # 0 means "disable"
            raise ValueError("keep_alive_milliseconds must be >= 0")
        else:
            if keep_alive_milliseconds == 0:
                log.info("Keep-alive ping disabled")
            elif keep_alive_milliseconds != DEFAULT_KEEP_ALIVE_MS:
                log.info("Keep-alive ping configured every %d milliseconds", keep_alive_milliseconds)
        self._keep_alive_milliseconds = keep_alive_milliseconds

        if check_unused_sessions_milliseconds <= 0:
            raise ValueError("check_unused_sessions_milliseconds must be > 0")
        elif check_unused_sessions_milliseconds != DEFAULT_CHECK_UNUSED_MS:
            log.info("Check for unused sessions every %d milliseconds", check_unused_sessions_milliseconds)
        self._check_unused_sessions_milliseconds = check_unused_sessions_milliseconds

        if unused_session_lifetime_milliseconds <= 0:
            # The message used to name check_unused_sessions_milliseconds,
            # pointing users at the wrong argument.
            raise ValueError("unused_session_lifetime_milliseconds must be > 0")
        elif unused_session_lifetime_milliseconds != DEFAULT_UNUSED_LIFETIME_MS:
            log.info("Unused sessions last for %d milliseconds", unused_session_lifetime_milliseconds)
        self._unused_session_lifetime_milliseconds = unused_session_lifetime_milliseconds

        if stats_log_frequency_milliseconds <= 0:
            raise ValueError("stats_log_frequency_milliseconds must be > 0")
        elif stats_log_frequency_milliseconds != DEFAULT_STATS_LOG_FREQ_MS:
            log.info("Log statistics every %d milliseconds", stats_log_frequency_milliseconds)
        self._stats_log_frequency_milliseconds = stats_log_frequency_milliseconds

        if mem_log_frequency_milliseconds < 0:
            # 0 means "disable"
            raise ValueError("mem_log_frequency_milliseconds must be >= 0")
        elif mem_log_frequency_milliseconds > 0:
            if import_optional('psutil') is None:
                log.warning("Memory logging requested, but is disabled. Optional dependency 'psutil' is missing. "
                            "Try 'pip install psutil' or 'conda install psutil'")
                mem_log_frequency_milliseconds = 0
            elif mem_log_frequency_milliseconds != DEFAULT_MEM_LOG_FREQ_MS:
                log.info("Log memory usage every %d milliseconds", mem_log_frequency_milliseconds)
        self._mem_log_frequency_milliseconds = mem_log_frequency_milliseconds

        if websocket_max_message_size_bytes <= 0:
            raise ValueError("websocket_max_message_size_bytes must be positive")
        elif websocket_max_message_size_bytes != DEFAULT_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES:
            log.info("Torndado websocket_max_message_size set to %d bytes (%0.2f MB)",
                     websocket_max_message_size_bytes,
                     websocket_max_message_size_bytes/1024.0**2)

        self.auth_provider = auth_provider

        if self.auth_provider.get_user or self.auth_provider.get_user_async:
            log.info("User authentication hooks provided (no default user)")
        else:
            log.info("User authentication hooks NOT provided (default user enabled)")

        kwargs['xsrf_cookies'] = xsrf_cookies
        if xsrf_cookies:
            log.info("XSRF cookie protection enabled")

        if extra_websocket_origins is None:
            self._websocket_origins = set()
        else:
            self._websocket_origins = set(extra_websocket_origins)
        self._secret_key = secret_key
        self._sign_sessions = sign_sessions
        self._generate_session_ids = generate_session_ids

        log.debug("These host origins can connect to the websocket: %r", list(self._websocket_origins))

        # Wrap applications in ApplicationContext
        self._applications = dict()
        for k,v in applications.items():
            self._applications[k] = ApplicationContext(v, url=k, logout_url=self.auth_provider.logout_url)

        # Copy before extending so the caller's sequence is never mutated
        # (extending in place would append the auth endpoints to it again
        # on every construction that reuses the same list).
        extra_patterns = list(extra_patterns or [])
        extra_patterns.extend(self.auth_provider.endpoints)

        all_patterns = []
        for key, app in applications.items():
            app_patterns = []
            for p in per_app_patterns:
                if key == "/":
                    route = p[0]
                else:
                    route = key + p[0]
                route = self._prefix + route
                app_patterns.append((route, p[1], { "application_context" : self._applications[key] }))

            # Every handler of the app is told where its websocket lives.
            websocket_path = None
            for r in app_patterns:
                if r[0].endswith("/ws"):
                    websocket_path = r[0]
            if not websocket_path:
                raise RuntimeError("Couldn't find websocket path")
            for r in app_patterns:
                r[2]["bokeh_websocket_path"] = websocket_path

            all_patterns.extend(app_patterns)

            # add a per-app static path if requested by the application
            if app.static_path is not None:
                if key == "/":
                    route = "/static/(.*)"
                else:
                    route = key + "/static/(.*)"
                route = self._prefix + route
                all_patterns.append((route, StaticFileHandler, { "path" : app.static_path }))

        for p in extra_patterns + toplevel_patterns:
            if p[1] == RootHandler:
                # The root handler is only installed when an index is wanted.
                if use_index:
                    data = {"applications": self._applications,
                            "prefix": self._prefix,
                            "index": self._index,
                            "use_redirect": redirect_root}
                    prefixed_pat = (self._prefix + p[0],) + p[1:] + (data,)
                    all_patterns.append(prefixed_pat)
            else:
                prefixed_pat = (self._prefix + p[0],) + p[1:]
                all_patterns.append(prefixed_pat)

        log.debug("Patterns are:")
        for line in pformat(all_patterns, width=60).split("\n"):
            log.debug(" " + line)

        super().__init__(all_patterns,
                         websocket_max_message_size=websocket_max_message_size_bytes,
                         **kwargs)

    def initialize(self, io_loop):
        ''' Start a Bokeh Server Tornado Application on a given Tornado IOLoop.

        Creates (but does not start) the periodic jobs; see ``start``.

        Args:
            io_loop : the IOLoop stored on this object and on every
                application context.

        '''
        self._loop = io_loop

        for app_context in self._applications.values():
            app_context._loop = self._loop

        self._clients = set()

        self._stats_job = PeriodicCallback(self._log_stats,
                                           self._stats_log_frequency_milliseconds)

        if self._mem_log_frequency_milliseconds > 0:
            self._mem_job = PeriodicCallback(self._log_mem,
                                             self._mem_log_frequency_milliseconds)
        else:
            self._mem_job = None

        self._cleanup_job = PeriodicCallback(self._cleanup_sessions,
                                             self._check_unused_sessions_milliseconds)

        if self._keep_alive_milliseconds > 0:
            self._ping_job = PeriodicCallback(self._keep_alive, self._keep_alive_milliseconds)
        else:
            self._ping_job = None

    @property
    def app_paths(self):
        ''' A set of all application paths for all Bokeh applications
        configured on this Bokeh server instance.

        '''
        return set(self._applications)

    @property
    def index(self):
        ''' Path to a Jinja2 template to serve as the index "/"

        '''
        return self._index

    @property
    def io_loop(self):
        ''' The Tornado IOLoop that this Bokeh Server Tornado Application
        is running on (None until ``initialize`` has been called).

        '''
        return self._loop

    @property
    def prefix(self):
        ''' A URL prefix for this Bokeh Server Tornado Application to use
        for all paths

        '''
        return self._prefix

    @property
    def websocket_origins(self):
        ''' A set of websocket origins permitted to connect to this server.

        '''
        return self._websocket_origins

    @property
    def secret_key(self):
        ''' A secret key for this Bokeh Server Tornado Application to use when
        signing session IDs, if configured.

        '''
        return self._secret_key

    @property
    def sign_sessions(self):
        ''' Whether this Bokeh Server Tornado Application has been configured
        to cryptographically sign session IDs

        If ``True``, then ``secret_key`` must also have been configured.

        '''
        return self._sign_sessions

    @property
    def generate_session_ids(self):
        ''' Whether this Bokeh Server Tornado Application has been configured
        to automatically generate session IDs.

        '''
        return self._generate_session_ids

    def resources(self, absolute_url=None):
        ''' Provide a :class:`~bokeh.resources.Resources` that specifies where
        Bokeh application sessions should load BokehJS resources from.

        Args:
            absolute_url (str, optional):
                An absolute URL prefix to use for locating resources. If
                None, relative URLs are used (default: None)

        '''
        mode = settings.resources(default="server")
        if mode == "server":
            root_url = urljoin(absolute_url, self._prefix) if absolute_url else self._prefix
            return Resources(mode="server", root_url=root_url, path_versioner=StaticHandler.append_version)
        return Resources(mode=mode)

    def start(self):
        ''' Start the Bokeh Server application.

        Starting the Bokeh Server Tornado application will run periodic
        callbacks for stats logging, cleanup, pinging, etc. Additionally, any
        startup hooks defined by the configured Bokeh applications will be run.

        ``initialize`` must have been called first.

        '''
        self._stats_job.start()
        if self._mem_job is not None:
            self._mem_job.start()
        self._cleanup_job.start()
        if self._ping_job is not None:
            self._ping_job.start()

        for context in self._applications.values():
            self._loop.spawn_callback(context.run_load_hook)

    def stop(self, wait=True):
        ''' Stop the Bokeh Server application.

        Args:
            wait (bool): whether to wait for orderly cleanup (default: True)
                Currently not used by this method.

        Returns:
            None

        '''

        # TODO should probably close all connections and shut down all sessions here
        for context in self._applications.values():
            context.run_unload_hook()

        self._stats_job.stop()
        if self._mem_job is not None:
            self._mem_job.stop()
        self._cleanup_job.stop()
        if self._ping_job is not None:
            self._ping_job.stop()

        self._clients.clear()

    def new_connection(self, protocol, socket, application_context, session):
        ''' Create a ``ServerConnection``, track it as a client and return it.

        '''
        connection = ServerConnection(protocol, socket, application_context, session)
        self._clients.add(connection)
        return connection

    def client_lost(self, connection):
        ''' Stop tracking ``connection`` and detach it from its session.

        '''
        self._clients.discard(connection)
        connection.detach_session()

    def get_session(self, app_path, session_id):
        ''' Get an active a session by name application path and session ID.

        Args:
            app_path (str) :
                The configured application path for the application to return
                a session for.

            session_id (str) :
                The session ID of the session to retrieve.

        Returns:
            ServerSession

        Raises:
            ValueError : if ``app_path`` is not configured on this server.

        '''
        if app_path not in self._applications:
            raise ValueError("Application %s does not exist on this server" % app_path)
        return self._applications[app_path].get_session(session_id)

    def get_sessions(self, app_path):
        ''' Gets all currently active sessions for an application.

        Args:
            app_path (str) :
                The configured application path for the application to return
                sessions for.

        Returns:
            list[ServerSession]

        Raises:
            ValueError : if ``app_path`` is not configured on this server.

        '''
        if app_path not in self._applications:
            raise ValueError("Application %s does not exist on this server" % app_path)
        return list(self._applications[app_path].sessions)

    # Periodic Callbacks ------------------------------------------------------

    async def _cleanup_sessions(self):
        ''' Ask every application context to clean up unused sessions.

        '''
        log.trace("Running session cleanup job")
        for app in self._applications.values():
            await app._cleanup_sessions(self._unused_session_lifetime_milliseconds)
        return None

    def _log_stats(self):
        ''' Log client and per-application session counts at debug level.

        '''
        log.trace("Running stats log job")

        if log.getEffectiveLevel() > logging.DEBUG:
            # avoid the work below if we aren't going to log anything
            return

        log.debug("[pid %d] %d clients connected", os.getpid(), len(self._clients))
        for app_path, app in self._applications.items():
            sessions = list(app.sessions)
            unused_count = 0
            for s in sessions:
                if s.connection_count == 0:
                    unused_count += 1
            log.debug("[pid %d]   %s has %d sessions with %d unused",
                      os.getpid(), app_path, len(sessions), unused_count)

    def _log_mem(self):
        ''' Log process memory usage, plus uncollected object counts at
        debug level.

        '''
        import psutil

        process = psutil.Process(os.getpid())
        # Read once, and use true division: the values are formatted with
        # %0.2f, which floor division always rendered as "N.00".
        mem = process.memory_info()
        log.info("[pid %d] Memory usage: %0.2f MB (RSS), %0.2f MB (VMS)",
                 os.getpid(), mem.rss / 2**20, mem.vms / 2**20)

        if log.getEffectiveLevel() > logging.DEBUG:
            # avoid the work below if we aren't going to log anything else
            return

        import gc
        from ..document import Document
        from ..model import Model
        from .session import ServerSession
        for name, typ in [('Documents', Document), ('Sessions', ServerSession), ('Models', Model)]:
            objs = [x for x in gc.get_objects() if isinstance(x, typ)]
            log.debug(" uncollected %s: %d", name, len(objs))

    def _keep_alive(self):
        ''' Send a ping on every tracked client connection.

        '''
        log.trace("Running keep alive job")
        for c in self._clients:
            c.send_ping()
class EventHandler:
    """Dispatch celery task events to per-task listener callbacks.

    Events are pulled from an ``EventMonitor`` queue inside ``start``.
    Events whose type is listed in ``finished_events`` are additionally
    remembered in a bounded cache so that a listener registered after the
    task has finished still receives its final event.
    """

    # Interval handed unchanged to PeriodicCallback; tornado documents that
    # argument as milliseconds, so this is 5 seconds rather than 5000 seconds.
    events_enable_interval = 5000

    # Maximum number of finished items to keep track of
    max_finished_history = 1000

    # celery events that represent a task finishing
    finished_events = (
        'task-succeeded',
        'task-failed',
        'task-rejected',
        'task-revoked',
    )

    def __init__(self, capp, io_loop):
        """Set up the periodic "enable events" timer and the event monitor.

        capp - The celery app
        io_loop - The event loop handed to the EventMonitor for dispatch
        """
        super().__init__()

        self.capp = capp
        self.timer = PeriodicCallback(self.on_enable_events,
                                      self.events_enable_interval)
        self.monitor = EventMonitor(self.capp, io_loop)

        # task_id -> callback waiting for events of that task
        self.listeners = {}
        # task_id -> last "finished" event seen for that task
        self.finished_tasks = LRUCache(self.max_finished_history)

    @tornado.gen.coroutine
    def start(self):
        """Start the timer and the monitor, then dispatch events forever.

        Expects to be run as a coroutine; it never returns on its own.
        """
        self.timer.start()

        logger.debug('Starting celery monitor thread')
        self.monitor.start()

        while True:
            event = yield self.monitor.events.get()

            # Events without a task uuid cannot be routed to a listener.
            if 'uuid' not in event:
                continue
            task_id = event['uuid']

            # Remember terminal events in case they are requested too late
            # or are requested again.
            if event['type'] in self.finished_events:
                self.finished_tasks[task_id] = event

            if task_id in self.listeners:
                self.listeners[task_id](event)

    def stop(self):
        """Stop the periodic timer; the monitor itself is left running."""
        self.timer.stop()
        # FIXME: can not be stopped gracefully
        # self.monitor.stop()

    def on_enable_events(self):
        """Periodically ask workers to emit events.

        Covers workers that were launched after the monitor started.
        Failures are logged at debug level and otherwise ignored.
        """
        try:
            self.capp.control.enable_events()
        except Exception as e:
            logger.debug('Failed to enable events: %s', e)

    def add_listener(self, task_id, callback):
        """Add event listener for a task with ID `task_id`.

        If a finished event for the task is already cached, the callback is
        invoked immediately with it and is not registered.
        """
        try:
            finished_event = self.finished_tasks[task_id]
        except KeyError:
            self.listeners[task_id] = callback
            return

        # Task has already finished
        callback(finished_event)

    def remove_listener(self, task_id):
        """Remove listener for `task_id`; a missing entry is not an error."""
        # The listener may never have been stored (finished event was cached).
        self.listeners.pop(task_id, None)
class BokehTornado(TornadoApplication):
    ''' Tornado Application that implements the Bokeh Server.

    The Server class is the main public interface; this class holds the
    Tornado implementation details.

    Args:
        applications (dict of str : bokeh.application.Application) : map from paths to Application instances
            The application is used to create documents for each session.
        prefix (str) : a URL prefix to use for all Bokeh server paths
        hosts (list) : hosts that are valid values for the Host header
        extra_websocket_origins (list) : hosts that can connect to the websocket
            These are in addition to ``hosts``.
        io_loop : the IOLoop to run on; ``IOLoop.current()`` when None
        extra_patterns (seq[tuple]) : tuples of (str, http or websocket handler)
            Use this argument to add additional endpoints to custom deployments
            of the Bokeh Server.
        secret_key (str) : secret key for signing session IDs
        sign_sessions (boolean) : whether to sign session IDs
        generate_session_ids (boolean) : whether to generate a session ID when none is provided
        keep_alive_milliseconds (int) : number of milliseconds between keep-alive pings
            Set to 0 to disable pings. Pings keep the websocket open.
        check_unused_sessions_milliseconds (int) : period of the session cleanup job
        unused_session_lifetime_milliseconds (int) : value handed to each
            application context's ``cleanup_sessions``
        stats_log_frequency_milliseconds (int) : period of the stats logging job
        develop (boolean) : True for develop mode

    '''

    def __init__(self, applications, prefix, hosts,
                 extra_websocket_origins,
                 io_loop=None,
                 extra_patterns=None,
                 secret_key=settings.secret_key_bytes(),
                 sign_sessions=settings.sign_sessions(),
                 generate_session_ids=True,
                 # heroku, nginx default to 60s timeout, so well less than that
                 keep_alive_milliseconds=37000,
                 # how often to check for unused sessions
                 check_unused_sessions_milliseconds=17000,
                 # how long unused sessions last
                 unused_session_lifetime_milliseconds=60*30*1000,
                 # how often to log stats
                 stats_log_frequency_milliseconds=15000,
                 develop=False):

        self._prefix = prefix
        self._loop = IOLoop.current() if io_loop is None else io_loop

        if keep_alive_milliseconds < 0:
            # 0 means "disable"
            raise ValueError("keep_alive_milliseconds must be >= 0")

        self._hosts = set(hosts)
        self._websocket_origins = self._hosts | set(extra_websocket_origins)
        self._resources = {}
        self._develop = develop
        self._secret_key = secret_key
        self._sign_sessions = sign_sessions
        self._generate_session_ids = generate_session_ids

        log.debug("Allowed Host headers: %r", list(self._hosts))
        log.debug("These host origins can connect to the websocket: %r", list(self._websocket_origins))

        # Wrap applications in ApplicationContext
        self._applications = {
            path: ApplicationContext(app, self._develop, self._loop)
            for path, app in applications.items()
        }

        all_patterns = []
        for app_path in applications:
            context = self._applications[app_path]
            # The root application is mounted directly under the prefix.
            route_base = "" if app_path == "/" else app_path
            app_patterns = [
                (self._prefix + route_base + pattern[0],
                 pattern[1],
                 {"application_context": context})
                for pattern in per_app_patterns
            ]

            # When several routes end in "/ws" the last one wins.
            ws_routes = [entry[0] for entry in app_patterns
                         if entry[0].endswith("/ws")]
            websocket_path = ws_routes[-1] if ws_routes else None
            if not websocket_path:
                raise RuntimeError("Couldn't find websocket path")

            # Every handler of the app is told where its websocket lives.
            for entry in app_patterns:
                entry[2]["bokeh_websocket_path"] = websocket_path
            all_patterns.extend(app_patterns)

        extra_patterns = extra_patterns or []
        for pattern in extra_patterns + toplevel_patterns:
            all_patterns.append((self._prefix + pattern[0],) + pattern[1:])

        for pattern in all_patterns:
            _whitelist(pattern[1])

        log.debug("Patterns are: %r", all_patterns)

        super(BokehTornado, self).__init__(all_patterns)

        self._clients = set()
        self._executor = ProcessPoolExecutor(max_workers=4)
        self._loop.add_callback(self._start_async)

        self._stats_job = PeriodicCallback(self.log_stats,
                                           stats_log_frequency_milliseconds,
                                           io_loop=self._loop)

        # NOTE(review): despite the "_seconds" suffix this attribute stores the
        # millisecond value unchanged - confirm the unit cleanup_sessions expects.
        self._unused_session_linger_seconds = unused_session_lifetime_milliseconds
        self._cleanup_job = PeriodicCallback(self.cleanup_sessions,
                                             check_unused_sessions_milliseconds,
                                             io_loop=self._loop)

        self._ping_job = None
        if keep_alive_milliseconds > 0:
            self._ping_job = PeriodicCallback(self.keep_alive,
                                              keep_alive_milliseconds,
                                              io_loop=self._loop)

    @property
    def io_loop(self):
        ''' The IOLoop this application was configured with. '''
        return self._loop

    @property
    def websocket_origins(self):
        ''' Set of hosts allowed to open websocket connections. '''
        return self._websocket_origins

    @property
    def secret_key(self):
        ''' The secret key passed to the constructor. '''
        return self._secret_key

    @property
    def sign_sessions(self):
        ''' The ``sign_sessions`` flag passed to the constructor. '''
        return self._sign_sessions

    @property
    def generate_session_ids(self):
        ''' The ``generate_session_ids`` flag passed to the constructor. '''
        return self._generate_session_ids

    def root_url_for_request(self, request):
        ''' Build ``<protocol>://<host><prefix>/`` for the given request. '''
        return request.protocol + "://" + request.host + self._prefix + "/"

    def websocket_url_for_request(self, request, websocket_path):
        ''' Build the ws:// or wss:// URL matching the request's protocol. '''
        # websocket_path comes from the handler, and already has any
        # prefix included, no need to add here
        protocol = "wss" if request.protocol == "https" else "ws"
        return protocol + "://" + request.host + websocket_path

    def resources(self, request):
        ''' Return the Resources for the request's root URL, cached per URL. '''
        root_url = self.root_url_for_request(request)
        if root_url in self._resources:
            return self._resources[root_url]
        resources = Resources(mode="server",
                              root_url=root_url,
                              path_versioner=StaticHandler.append_version)
        self._resources[root_url] = resources
        return resources

    def start(self, start_loop=True):
        ''' Start the Bokeh Server application main loop.

        Args:
            start_loop (boolean): False to not actually start event loop, used in tests

        Returns:
            None

        Notes:
            A keyboard interrupt while the loop runs is caught and reported
            with a message instead of propagating.

        '''
        self._stats_job.start()
        self._cleanup_job.start()
        if self._ping_job is not None:
            self._ping_job.start()

        for context in self._applications.values():
            context.run_load_hook()

        if not start_loop:
            return
        try:
            self._loop.start()
        except KeyboardInterrupt:
            print("\nInterrupted, shutting down")

    def stop(self):
        ''' Stop the Bokeh Server application.

        Runs every application's unload hook, stops the periodic jobs and
        finally stops the IOLoop.

        Returns:
            None

        '''
        # TODO we should probably close all connections and shut
        # down all sessions either here or in unlisten() ... but
        # it isn't that important since in real life it's rare to
        # do a clean shutdown (vs. a kill-by-signal) anyhow.

        for context in self._applications.values():
            context.run_unload_hook()

        self._stats_job.stop()
        self._cleanup_job.stop()
        if self._ping_job is not None:
            self._ping_job.stop()

        self._loop.stop()

    @property
    def executor(self):
        ''' The process pool used by ``run_in_background``. '''
        return self._executor

    def new_connection(self, protocol, socket, application_context, session):
        ''' Create a ServerConnection, track it as a client and return it. '''
        connection = ServerConnection(protocol, socket, application_context, session)
        self._clients.add(connection)
        return connection

    def client_lost(self, connection):
        ''' Forget a connection and detach it from its session. '''
        self._clients.discard(connection)
        connection.detach_session()

    def get_session(self, app_path, session_id):
        ''' Return the session ``session_id`` of the application at ``app_path``.

        Raises:
            ValueError : if no application is configured at ``app_path``

        '''
        if app_path not in self._applications:
            raise ValueError("Application %s does not exist on this server" % app_path)
        return self._applications[app_path].get_session(session_id)

    def get_sessions(self, app_path):
        ''' Return a list of the sessions of the application at ``app_path``.

        Raises:
            ValueError : if no application is configured at ``app_path``

        '''
        if app_path not in self._applications:
            raise ValueError("Application %s does not exist on this server" % app_path)
        return list(self._applications[app_path].sessions)

    @gen.coroutine
    def cleanup_sessions(self):
        ''' Periodic job: let every application context clean up its sessions. '''
        for context in self._applications.values():
            yield context.cleanup_sessions(self._unused_session_linger_seconds)
        raise gen.Return(None)

    def log_stats(self):
        ''' Periodic job: log client and per-application session counts at debug level. '''
        if log.getEffectiveLevel() > logging.DEBUG:
            # avoid the work below if we aren't going to log anything
            return

        pid = os.getpid()
        log.debug("[pid %d] %d clients connected", pid, len(self._clients))
        for app_path, context in self._applications.items():
            sessions = list(context.sessions)
            unused_count = sum(1 for session in sessions
                               if session.connection_count == 0)
            log.debug("[pid %d]   %s has %d sessions with %d unused",
                      os.getpid(), app_path, len(sessions), unused_count)

    def keep_alive(self):
        ''' Periodic job: send a ping on every tracked client connection. '''
        for client in self._clients:
            client.send_ping()

    @gen.coroutine
    def run_in_background(self, _func, *args, **kwargs):
        """
        Run a synchronous function in the background without disrupting
        the main thread. Useful for long-running jobs.

        The call is submitted to the process pool executor and its result
        is returned once the future completes.
        """
        result = yield self._executor.submit(_func, *args, **kwargs)
        raise gen.Return(result)

    @gen.coroutine
    def _start_async(self):
        ''' Register the atexit and SIGTERM handlers once the loop is running. '''
        try:
            atexit.register(self._atexit)
            signal.signal(signal.SIGTERM, self._sigterm)
        except Exception:
            # NOTE(review): no ``exit`` method is defined in this class; unless
            # the base class provides one this raises AttributeError - confirm.
            self.exit(1)

    # Guards _atexit so the shutdown work runs at most once.
    _atexit_ran = False

    def _atexit(self):
        ''' Run ``_cleanup`` on a fresh IOLoop, at most once per instance. '''
        if self._atexit_ran:
            return
        self._atexit_ran = True

        self._stats_job.stop()
        # A new loop is installed as current and used to drive _cleanup.
        IOLoop.clear_current()
        cleanup_loop = IOLoop()
        cleanup_loop.make_current()
        cleanup_loop.run_sync(self._cleanup)

    def _sigterm(self, signum, frame):
        ''' SIGTERM handler: stop the server, then run the atexit cleanup. '''
        print("Received SIGTERM, shutting down")
        self.stop()
        self._atexit()

    @gen.coroutine
    def _cleanup(self):
        ''' Shut the executor down without waiting and drop all clients. '''
        log.debug("Shutdown: cleaning up")
        self._executor.shutdown(wait=False)
        self._clients.clear()
class WebSocketProtocol13(WebSocketProtocol):
    """Implementation of the WebSocket protocol from RFC 6455.

    This class supports versions 7 and 8 of the protocol in addition to the
    final version 13.

    Incoming frames are parsed by a chain of stream-read callbacks:
    ``_receive_frame`` -> ``_on_frame_start`` -> (optional extended length)
    -> (optional masking key) -> ``_on_frame_data`` -> ``_handle_message``.
    Per-frame parser state lives in the ``_frame_*`` attributes.
    """
    # Bit masks for the first byte of a frame.
    FIN = 0x80
    RSV1 = 0x40
    RSV2 = 0x20
    RSV3 = 0x10
    RSV_MASK = RSV1 | RSV2 | RSV3
    OPCODE_MASK = 0x0f

    def __init__(self, handler, mask_outgoing=False,
                 compression_options=None):
        """Initialise parser state.

        handler - the handler this protocol instance serves
        mask_outgoing - when true, outgoing frames are masked with a random key
        compression_options - when not None, permessage-deflate may be
            negotiated; the value is forwarded to the compressor objects
        """
        WebSocketProtocol.__init__(self, handler)
        self.mask_outgoing = mask_outgoing
        self._final_frame = False
        self._frame_opcode = None
        self._masked_frame = None
        self._frame_mask = None
        self._frame_length = None
        self._fragmented_message_buffer = None
        self._fragmented_message_opcode = None
        self._waiting = None
        self._compression_options = compression_options
        self._decompressor = None
        self._compressor = None
        self._frame_compressed = None
        # The total uncompressed size of all messages received or sent.
        # Unicode messages are encoded to utf8.
        # Only for testing; subject to change.
        self._message_bytes_in = 0
        self._message_bytes_out = 0
        # The total size of all packets received or sent.  Includes
        # the effect of compression, frame overhead, and control frames.
        self._wire_bytes_in = 0
        self._wire_bytes_out = 0
        self.ping_callback = None
        self.last_ping = 0
        self.last_pong = 0

    def accept_connection(self):
        """Validate the handshake request and complete the upgrade.

        Missing required headers finish the request with status 400; a
        ValueError raised while accepting aborts the connection.
        """
        try:
            self._handle_websocket_headers()
        except ValueError:
            self.handler.set_status(400)
            log_msg = "Missing/Invalid WebSocket headers"
            self.handler.finish(log_msg)
            gen_log.debug(log_msg)
            return

        try:
            self._accept_connection()
        except ValueError:
            gen_log.debug("Malformed WebSocket request received",
                          exc_info=True)
            self._abort()
            return

    def _handle_websocket_headers(self):
        """Verifies all invariant- and required headers

        If a header is missing or have an incorrect value ValueError will be
        raised
        """
        fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version")
        if not all(map(lambda f: self.request.headers.get(f), fields)):
            raise ValueError("Missing/Invalid WebSocket headers")

    @staticmethod
    def compute_accept_value(key):
        """Computes the value for the Sec-WebSocket-Accept header,
        given the value for Sec-WebSocket-Key.
        """
        sha1 = hashlib.sha1()
        sha1.update(utf8(key))
        sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11")  # Magic value
        return native_str(base64.b64encode(sha1.digest()))

    def _challenge_response(self):
        """Return the accept value for this request's Sec-Websocket-Key."""
        return WebSocketProtocol13.compute_accept_value(
            self.request.headers.get("Sec-Websocket-Key"))

    def _accept_connection(self):
        """Negotiate subprotocol/extensions, send the 101 response and
        start reading frames.
        """
        subprotocols = self.request.headers.get("Sec-WebSocket-Protocol", '')
        subprotocols = [s.strip() for s in subprotocols.split(',')]
        # NOTE(review): ''.split(',') yields [''], so this list is never empty
        # and select_subprotocol is also called when the header is absent.
        if subprotocols:
            selected = self.handler.select_subprotocol(subprotocols)
            if selected:
                assert selected in subprotocols
                self.handler.set_header("Sec-WebSocket-Protocol", selected)

        extensions = self._parse_extensions_header(self.request.headers)
        for ext in extensions:
            if (ext[0] == 'permessage-deflate' and
                    self._compression_options is not None):
                # TODO: negotiate parameters if compression_options
                # specifies limits.
                self._create_compressors('server', ext[1], self._compression_options)
                if ('client_max_window_bits' in ext[1] and
                        ext[1]['client_max_window_bits'] is None):
                    # Don't echo an offered client_max_window_bits
                    # parameter with no value.
                    del ext[1]['client_max_window_bits']
                self.handler.set_header("Sec-WebSocket-Extensions",
                                        httputil._encode_header(
                                            'permessage-deflate', ext[1]))
                # Only the first acceptable offer is used.
                break

        self.handler.clear_header("Content-Type")
        self.handler.set_status(101)
        self.handler.set_header("Upgrade", "websocket")
        self.handler.set_header("Connection", "Upgrade")
        self.handler.set_header("Sec-WebSocket-Accept", self._challenge_response())
        self.handler.finish()

        self.handler._attach_stream()
        self.stream = self.handler.stream

        self.start_pinging()
        self._run_callback(self.handler.open, *self.handler.open_args,
                           **self.handler.open_kwargs)
        self._receive_frame()

    def _parse_extensions_header(self, headers):
        """Parse Sec-WebSocket-Extensions into a list of parsed headers.

        Returns an empty list when the header is absent or empty.
        """
        extensions = headers.get("Sec-WebSocket-Extensions", '')
        if extensions:
            return [httputil._parse_header(e.strip())
                    for e in extensions.split(',')]
        return []

    def _process_server_headers(self, key, headers):
        """Process the headers sent by the server to this client connection.

        'key' is the websocket handshake challenge/response key.

        Raises ValueError if the server selected an extension other than
        permessage-deflate, or selected it while compression is disabled.
        """
        assert headers['Upgrade'].lower() == 'websocket'
        assert headers['Connection'].lower() == 'upgrade'
        accept = self.compute_accept_value(key)
        assert headers['Sec-Websocket-Accept'] == accept

        extensions = self._parse_extensions_header(headers)
        for ext in extensions:
            if (ext[0] == 'permessage-deflate' and
                    self._compression_options is not None):
                self._create_compressors('client', ext[1])
            else:
                # ext is indexed as a sequence above, so wrap it in a 1-tuple
                # to format it as a single %r value.
                raise ValueError("unsupported extension %r" % (ext,))

    def _get_compressor_options(self, side, agreed_parameters, compression_options=None):
        """Converts a websocket agreed_parameters set to keyword arguments
        for our compressor objects.

        side - 'client' or 'server'; selects which ``<side>_*`` parameters
            are read
        """
        options = dict(
            persistent=(side + '_no_context_takeover') not in agreed_parameters)
        wbits_header = agreed_parameters.get(side + '_max_window_bits', None)
        if wbits_header is None:
            options['max_wbits'] = zlib.MAX_WBITS
        else:
            options['max_wbits'] = int(wbits_header)
        options['compression_options'] = compression_options
        return options

    def _create_compressors(self, side, agreed_parameters, compression_options=None):
        """Create the compressor for ``side`` and the decompressor for the
        opposite side.

        Raises ValueError on any parameter outside the four known keys.
        """
        # TODO: handle invalid parameters gracefully
        allowed_keys = set(['server_no_context_takeover',
                            'client_no_context_takeover',
                            'server_max_window_bits',
                            'client_max_window_bits'])
        for key in agreed_parameters:
            if key not in allowed_keys:
                raise ValueError("unsupported compression parameter %r" % key)
        other_side = 'client' if (side == 'server') else 'server'
        self._compressor = _PerMessageDeflateCompressor(
            **self._get_compressor_options(side, agreed_parameters, compression_options))
        self._decompressor = _PerMessageDeflateDecompressor(
            **self._get_compressor_options(other_side, agreed_parameters, compression_options))

    def _write_frame(self, fin, opcode, data, flags=0):
        """Serialise one frame and write it to the stream.

        Returns whatever ``self.stream.write`` returns.
        """
        if fin:
            finbit = self.FIN
        else:
            finbit = 0
        frame = struct.pack("B", finbit | opcode | flags)
        data_len = len(data)
        if self.mask_outgoing:
            mask_bit = 0x80
        else:
            mask_bit = 0
        # Length encoding: 7-bit inline, or marker 126 + 16-bit length,
        # or marker 127 + 64-bit length.
        if data_len < 126:
            frame += struct.pack("B", data_len | mask_bit)
        elif data_len <= 0xFFFF:
            frame += struct.pack("!BH", 126 | mask_bit, data_len)
        else:
            frame += struct.pack("!BQ", 127 | mask_bit, data_len)
        if self.mask_outgoing:
            mask = os.urandom(4)
            data = mask + _websocket_mask(mask, data)
        frame += data
        self._wire_bytes_out += len(frame)
        return self.stream.write(frame)

    def write_message(self, message, binary=False):
        """Sends the given message to the client of this Web Socket.

        Returns a future; StreamClosedError is translated to
        WebSocketClosedError both synchronously and through the future.
        """
        if binary:
            opcode = 0x2
        else:
            opcode = 0x1
        message = tornado.escape.utf8(message)
        assert isinstance(message, bytes)
        self._message_bytes_out += len(message)
        flags = 0
        if self._compressor:
            message = self._compressor.compress(message)
            flags |= self.RSV1
        # For historical reasons, write methods in Tornado operate in a semi-synchronous
        # mode in which awaiting the Future they return is optional (But errors can
        # still be raised). This requires us to go through an awkward dance here
        # to transform the errors that may be returned while presenting the same
        # semi-synchronous interface.
        try:
            fut = self._write_frame(True, opcode, message, flags=flags)
        except StreamClosedError:
            raise WebSocketClosedError()

        @gen.coroutine
        def wrapper():
            try:
                yield fut
            except StreamClosedError:
                raise WebSocketClosedError()
        return wrapper()

    def write_ping(self, data):
        """Send ping frame."""
        assert isinstance(data, bytes)
        self._write_frame(True, 0x9, data)

    def _receive_frame(self):
        """Start reading the two-byte header of the next frame."""
        try:
            self.stream.read_bytes(2, self._on_frame_start)
        except StreamClosedError:
            self._abort()

    def _on_frame_start(self, data):
        """Parse the two header bytes and schedule the next read."""
        self._wire_bytes_in += len(data)
        header, payloadlen = struct.unpack("BB", data)
        self._final_frame = header & self.FIN
        reserved_bits = header & self.RSV_MASK
        self._frame_opcode = header & self.OPCODE_MASK
        self._frame_opcode_is_control = self._frame_opcode & 0x8
        if self._decompressor is not None and self._frame_opcode != 0:
            # RSV1 is consumed as the "compressed" flag once a decompressor
            # has been negotiated; continuation frames keep the earlier value.
            self._frame_compressed = bool(reserved_bits & self.RSV1)
            reserved_bits &= ~self.RSV1
        if reserved_bits:
            # client is using as-yet-undefined extensions; abort
            self._abort()
            return
        self._masked_frame = bool(payloadlen & 0x80)
        payloadlen = payloadlen & 0x7f
        if self._frame_opcode_is_control and payloadlen >= 126:
            # control frames must have payload < 126
            self._abort()
            return
        try:
            if payloadlen < 126:
                self._frame_length = payloadlen
                if self._masked_frame:
                    self.stream.read_bytes(4, self._on_masking_key)
                else:
                    self._read_frame_data(False)
            elif payloadlen == 126:
                self.stream.read_bytes(2, self._on_frame_length_16)
            elif payloadlen == 127:
                self.stream.read_bytes(8, self._on_frame_length_64)
        except StreamClosedError:
            self._abort()

    def _read_frame_data(self, masked):
        """Read the frame payload, enforcing the maximum message size.

        The limit applies to the frame plus any buffered fragments and
        falls back to 10 MiB when the handler's max_message_size is falsy.
        """
        new_len = self._frame_length
        if self._fragmented_message_buffer is not None:
            new_len += len(self._fragmented_message_buffer)
        if new_len > (self.handler.max_message_size or 10 * 1024 * 1024):
            self.close(1009, "message too big")
            return
        self.stream.read_bytes(
            self._frame_length,
            self._on_masked_frame_data if masked else self._on_frame_data)

    def _on_frame_length_16(self, data):
        """Handle the 16-bit extended payload length."""
        self._wire_bytes_in += len(data)
        self._frame_length = struct.unpack("!H", data)[0]
        try:
            if self._masked_frame:
                self.stream.read_bytes(4, self._on_masking_key)
            else:
                self._read_frame_data(False)
        except StreamClosedError:
            self._abort()

    def _on_frame_length_64(self, data):
        """Handle the 64-bit extended payload length."""
        self._wire_bytes_in += len(data)
        self._frame_length = struct.unpack("!Q", data)[0]
        try:
            if self._masked_frame:
                self.stream.read_bytes(4, self._on_masking_key)
            else:
                self._read_frame_data(False)
        except StreamClosedError:
            self._abort()

    def _on_masking_key(self, data):
        """Store the 4-byte masking key and read the masked payload."""
        self._wire_bytes_in += len(data)
        self._frame_mask = data
        try:
            self._read_frame_data(True)
        except StreamClosedError:
            self._abort()

    def _on_masked_frame_data(self, data):
        """Unmask the payload and continue with the common data path."""
        # Don't touch _wire_bytes_in; we'll do it in _on_frame_data.
        self._on_frame_data(_websocket_mask(self._frame_mask, data))

    def _on_frame_data(self, data):
        """Assemble fragments, dispatch complete messages and keep reading."""
        handled_future = None

        self._wire_bytes_in += len(data)
        if self._frame_opcode_is_control:
            # control frames may be interleaved with a series of fragmented
            # data frames, so control frames must not interact with
            # self._fragmented_*
            if not self._final_frame:
                # control frames must not be fragmented
                self._abort()
                return
            opcode = self._frame_opcode
        elif self._frame_opcode == 0:  # continuation frame
            if self._fragmented_message_buffer is None:
                # nothing to continue
                self._abort()
                return
            self._fragmented_message_buffer += data
            if self._final_frame:
                opcode = self._fragmented_message_opcode
                data = self._fragmented_message_buffer
                self._fragmented_message_buffer = None
        else:  # start of new data message
            if self._fragmented_message_buffer is not None:
                # can't start new message until the old one is finished
                self._abort()
                return
            if self._final_frame:
                opcode = self._frame_opcode
            else:
                self._fragmented_message_opcode = self._frame_opcode
                self._fragmented_message_buffer = data

        if self._final_frame:
            handled_future = self._handle_message(opcode, data)

        if not self.client_terminated:
            if handled_future:
                # on_message is a coroutine, process more frames once it's done.
                handled_future.add_done_callback(
                    lambda future: self._receive_frame())
            else:
                self._receive_frame()

    def _handle_message(self, opcode, data):
        """Execute on_message, returning its Future if it is a coroutine."""
        if self.client_terminated:
            return

        if self._frame_compressed:
            data = self._decompressor.decompress(data)

        if opcode == 0x1:
            # UTF-8 data
            self._message_bytes_in += len(data)
            try:
                decoded = data.decode("utf-8")
            except UnicodeDecodeError:
                self._abort()
                return
            return self._run_callback(self.handler.on_message, decoded)
        elif opcode == 0x2:
            # Binary data
            self._message_bytes_in += len(data)
            return self._run_callback(self.handler.on_message, data)
        elif opcode == 0x8:
            # Close
            self.client_terminated = True
            if len(data) >= 2:
                self.handler.close_code = struct.unpack('>H', data[:2])[0]
            if len(data) > 2:
                self.handler.close_reason = to_unicode(data[2:])
            # Echo the received close code, if any (RFC 6455 section 5.5.1).
            self.close(self.handler.close_code)
        elif opcode == 0x9:
            # Ping
            try:
                self._write_frame(True, 0xA, data)
            except StreamClosedError:
                self._abort()
            self._run_callback(self.handler.on_ping, data)
        elif opcode == 0xA:
            # Pong
            self.last_pong = IOLoop.current().time()
            return self._run_callback(self.handler.on_pong, data)
        else:
            self._abort()

    def close(self, code=None, reason=None):
        """Closes the WebSocket connection.

        Sends a close frame if the server side has not terminated yet, then
        either closes the stream (the client already sent its close frame)
        or schedules an abort five seconds later.
        """
        if not self.server_terminated:
            if not self.stream.closed():
                if code is None and reason is not None:
                    code = 1000  # "normal closure" status code
                if code is None:
                    close_data = b''
                else:
                    close_data = struct.pack('>H', code)
                if reason is not None:
                    close_data += utf8(reason)
                try:
                    self._write_frame(True, 0x8, close_data)
                except StreamClosedError:
                    self._abort()
            self.server_terminated = True
        if self.client_terminated:
            if self._waiting is not None:
                self.stream.io_loop.remove_timeout(self._waiting)
                self._waiting = None
            self.stream.close()
        elif self._waiting is None:
            # Give the client a few seconds to complete a clean shutdown,
            # otherwise just close the connection.
            self._waiting = self.stream.io_loop.add_timeout(
                self.stream.io_loop.time() + 5, self._abort)

    @property
    def ping_interval(self):
        """The handler's ping interval, or 0 when the handler has none."""
        interval = self.handler.ping_interval
        if interval is not None:
            return interval
        return 0

    @property
    def ping_timeout(self):
        """The handler's ping timeout, defaulting to
        ``max(3 * ping_interval, 30)``.
        """
        timeout = self.handler.ping_timeout
        if timeout is not None:
            return timeout
        return max(3 * self.ping_interval, 30)

    def start_pinging(self):
        """Start sending periodic pings to keep the connection alive"""
        if self.ping_interval > 0:
            self.last_ping = self.last_pong = IOLoop.current().time()
            # ping_interval is multiplied by 1000 for PeriodicCallback.
            self.ping_callback = PeriodicCallback(
                self.periodic_ping, self.ping_interval * 1000)
            self.ping_callback.start()

    def periodic_ping(self):
        """Send a ping to keep the websocket alive

        Called periodically if the websocket_ping_interval is set and non-zero.
        """
        if self.stream.closed() and self.ping_callback is not None:
            self.ping_callback.stop()
            return

        # Check for timeout on pong.  Make sure that we really have
        # sent a recent ping in case the machine with both server and
        # client has been suspended since the last ping.
        now = IOLoop.current().time()
        since_last_pong = now - self.last_pong
        since_last_ping = now - self.last_ping
        if (since_last_ping < 2 * self.ping_interval and
                since_last_pong > self.ping_timeout):
            self.close()
            return

        self.write_ping(b'')
        self.last_ping = now
class AdaptiveCore:
    """
    The core logic for adaptive deployments, with none of the cluster details

    This class controls our adaptive scaling behavior.  It is intended to be
    used as a super-class or mixin.  It expects the following state and methods:

    **State**

    plan: set
        A set of workers that we think should exist.
        Here and below worker is just a token, often an address or name string

    requested: set
        A set of workers that the cluster class has successfully requested from
        the resource manager.  We expect that resource manager to work to make
        these exist.

    observed: set
        A set of workers that have successfully checked in with the scheduler

    These sets are not necessarily equivalent.  Often plan and requested will
    be very similar (requesting is usually fast) but there may be a large delay
    between requested and observed (often resource managers don't give us what
    we want).

    **Functions**

    target : -> int
        Returns the target number of workers that should exist.
        This is often obtained by querying the scheduler

    workers_to_close : int -> Set[worker]
        Given a target number of workers,
        returns a set of workers that we should close when we're scaling down

    scale_up : int -> None
        Scales the cluster up to a target number of workers, presumably
        changing at least ``plan`` and hopefully eventually also ``requested``

    scale_down : Set[worker] -> None
        Closes the provided set of workers

    Parameters
    ----------
    minimum: int
        The minimum number of allowed workers
    maximum: int | inf
        The maximum number of allowed workers
    wait_count: int
        The number of scale-down requests we should receive before actually
        scaling down
    interval: str
        The amount of time, like ``"1s"`` between checks
    """

    minimum: int
    maximum: int | float
    wait_count: int
    interval: int | float
    periodic_callback: PeriodicCallback | None
    plan: set[WorkerState]
    requested: set[WorkerState]
    observed: set[WorkerState]
    close_counts: defaultdict[WorkerState, int]
    _adapting: bool
    log: deque[tuple[float, dict]]

    def __init__(
        self,
        minimum: int = 0,
        maximum: int | float = math.inf,
        wait_count: int = 3,
        interval: str | int | float | timedelta | None = "1s",
    ):
        """Validate the bounds and, if ``interval`` is non-zero, schedule
        the periodic ``adapt`` call on ``self.loop``.

        Raises
        ------
        TypeError
            If ``maximum`` is neither an int nor infinite.
        """
        # Set before any validation can raise: __del__ calls stop(), which
        # reads this attribute even when construction failed part-way.
        self.periodic_callback = None

        if not isinstance(maximum, int) and not math.isinf(maximum):
            raise TypeError(f"maximum must be int or inf; got {maximum}")

        self.minimum = minimum
        self.maximum = maximum
        self.wait_count = wait_count
        self.interval = parse_timedelta(interval, "seconds")

        def f():
            # periodic_callback may have been reset to None by stop() before
            # the loop gets around to running this.
            try:
                self.periodic_callback.start()
            except AttributeError:
                pass

        if self.interval:
            import weakref

            # Only a weak reference is captured so the periodic callback does
            # not keep this object alive.
            self_ref = weakref.ref(self)

            async def _adapt():
                core = self_ref()
                if core:
                    await core.adapt()

            # interval is in seconds; it is multiplied by 1000 for
            # PeriodicCallback.
            self.periodic_callback = PeriodicCallback(_adapt, self.interval * 1000)
            self.loop.add_callback(f)

        # Best effort: any exception raised while assigning these is ignored
        # (presumably for subclasses exposing them as read-only properties -
        # TODO confirm).
        try:
            self.plan = set()
            self.requested = set()
            self.observed = set()
        except Exception:
            pass

        # internal state
        self.close_counts = defaultdict(int)
        self._adapting = False
        self.log = deque(maxlen=10000)

    def stop(self) -> None:
        """Stop the periodic callback, if any, and drop the reference to it."""
        logger.info("Adaptive stop")

        if self.periodic_callback:
            self.periodic_callback.stop()
            self.periodic_callback = None

    async def target(self) -> int:
        """The target number of workers that should exist"""
        raise NotImplementedError()

    async def workers_to_close(self, target: int) -> list:
        """
        Give a list of workers to close that brings us down to target workers
        """
        # TODO, improve me with something that thinks about current load
        return list(self.observed)[target:]

    async def safe_target(self) -> int:
        """Used internally, like target, but respects minimum/maximum"""
        n = await self.target()
        if n > self.maximum:
            n = cast(int, self.maximum)

        if n < self.minimum:
            n = self.minimum

        return n

    async def scale_down(self, workers: Iterable) -> None:
        """Close the given workers; must be implemented by subclasses.

        ``adapt`` calls this as ``scale_down(workers=[...])``.
        """
        raise NotImplementedError()

    async def scale_up(self, n: int) -> None:
        """Scale up to ``n`` workers; must be implemented by subclasses.

        ``adapt`` calls this as ``scale_up(n=target)``.
        """
        raise NotImplementedError()

    async def recommendations(self, target: int) -> dict:
        """
        Make scale up/down recommendations based on current state and target

        Returns a dict whose ``"status"`` is ``"same"``, ``"up"`` (with key
        ``"n"``) or ``"down"`` (with key ``"workers"``).  A worker is only
        reported for closing after being a candidate in ``wait_count``
        consecutive calls.
        """
        plan = self.plan
        requested = self.requested
        observed = self.observed

        if target == len(plan):
            self.close_counts.clear()
            return {"status": "same"}

        if target > len(plan):
            self.close_counts.clear()
            return {"status": "up", "n": target}

        # target < len(plan)
        # Prefer dropping workers that were requested but never showed up.
        not_yet_arrived = requested - observed
        to_close = set()
        if not_yet_arrived:
            to_close.update(toolz.take(len(plan) - target, not_yet_arrived))

        if target < len(plan) - len(to_close):
            candidates = await self.workers_to_close(target=target)
            to_close.update(candidates)

        firmly_close = set()
        for w in to_close:
            self.close_counts[w] += 1
            if self.close_counts[w] >= self.wait_count:
                firmly_close.add(w)

        for k in list(self.close_counts):  # clear out unseen keys
            if k in firmly_close or k not in to_close:
                del self.close_counts[k]

        if firmly_close:
            return {"status": "down", "workers": list(firmly_close)}
        else:
            return {"status": "same"}

    async def adapt(self) -> None:
        """
        Check the current state, make recommendations, call scale

        This is the main event of the system.  An OSError raised while
        scaling down is logged and ignored; any other OSError stops the
        adaptive loop.
        """
        if self._adapting:  # Semaphore to avoid overlapping adapt calls
            return
        self._adapting = True
        status = None

        try:
            target = await self.safe_target()
            recommendations = await self.recommendations(target)

            if recommendations["status"] != "same":
                self.log.append((time(), dict(recommendations)))

            # What remains after popping "status" are the keyword arguments
            # for scale_up (n=...) or scale_down (workers=...).
            status = recommendations.pop("status")
            if status == "same":
                return
            if status == "up":
                await self.scale_up(**recommendations)
            if status == "down":
                await self.scale_down(**recommendations)
        except OSError:
            if status != "down":
                logger.error("Adaptive stopping due to error", exc_info=True)
                self.stop()
            else:
                logger.error(
                    "Error during adaptive downscaling. Ignoring.", exc_info=True
                )
        finally:
            self._adapting = False

    def __del__(self):
        self.stop()

    @property
    def loop(self) -> IOLoop:
        """The current IOLoop."""
        return IOLoop.current()
class DataStore:
    """Rolling in-memory store for sensor temperatures and gcode traffic.

    The store subscribes to server events to track the latest reading of
    every available sensor, samples those readings periodically into
    bounded deques, and records gcode commands/responses in a bounded
    queue.  Both stores are exposed through GET endpoints.
    """

    def __init__(self, config):
        """Read the store sizes from *config* and register handlers.

        :param config: configuration helper providing ``get_server()`` and
            ``getint(name, default)``.
        """
        self.server = config.get_server()
        self.temp_store_size = config.getint('temperature_store_size', 1200)
        self.gcode_store_size = config.getint('gcode_store_size', 1000)

        # Temperature Store Tracking
        # last_temps maps sensor name -> (temperature, target, power, speed)
        self.last_temps = {}
        self.gcode_queue = deque(maxlen=self.gcode_store_size)
        # temperature_store maps sensor name -> {field name: deque of samples}
        self.temperature_store = {}
        self.temp_update_cb = PeriodicCallback(
            self._update_temperature_store, TEMPERATURE_UPDATE_MS)

        # Register status update event
        self.server.register_event_handler(
            "server:status_update", self._set_current_temps)
        self.server.register_event_handler(
            "server:gcode_response", self._update_gcode_store)
        self.server.register_event_handler(
            "server:klippy_ready", self._init_sensors)

        # Register endpoints
        self.server.register_endpoint(
            "/server/temperature_store", ['GET'],
            self._handle_temp_store_request)
        self.server.register_endpoint(
            "/server/gcode_store", ['GET'],
            self._handle_gcode_store_request)

    async def _init_sensors(self):
        """Discover the available sensors and (re)build the temperature store.

        Queries the ``heaters`` object for ``available_sensors``, subscribes
        to each of them, keeps the history of sensors that were already
        tracked, and starts the periodic sampler.  When no sensors are
        reported the stores are cleared and the sampler is stopped.  Query
        or subscription errors are logged and leave the state untouched.
        """
        klippy_apis = self.server.lookup_component('klippy_apis')
        # Fetch sensors
        try:
            result = await klippy_apis.query_objects({'heaters': None})
        except self.server.error as e:
            logging.info(f"Error Configuring Sensors: {e}")
            return
        sensors = result.get("heaters", {}).get("available_sensors", [])

        if not sensors:
            logging.info("No sensors found")
            self.last_temps = {}
            self.temperature_store = {}
            self.temp_update_cb.stop()
            return

        # Add Subscription
        sub = {s: None for s in sensors}
        try:
            status = await klippy_apis.subscribe_objects(sub)
        except self.server.error as e:
            logging.info(f"Error subscribing to sensors: {e}")
            return
        logging.info(f"Configuring available sensors: {sensors}")

        new_store = {}
        for sensor in sensors:
            fields = list(status.get(sensor, {}).keys())
            if sensor in self.temperature_store:
                # Preserve the history collected before this (re)init.
                new_store[sensor] = self.temperature_store[sensor]
            else:
                new_store[sensor] = {
                    'temperatures': deque(maxlen=self.temp_store_size)}
                # Only track the optional fields this sensor reports.
                for item in ["target", "power", "speed"]:
                    if item in fields:
                        new_store[sensor][f"{item}s"] = deque(
                            maxlen=self.temp_store_size)
            if sensor not in self.last_temps:
                self.last_temps[sensor] = (0., 0., 0., 0.)
        self.temperature_store = new_store

        # Prune unconfigured sensors in self.last_temps
        for sensor in list(self.last_temps.keys()):
            if sensor not in self.temperature_store:
                del self.last_temps[sensor]

        # Update initial temperatures
        self._set_current_temps(status)
        self.temp_update_cb.start()

    def _set_current_temps(self, data):
        """Record the newest values found in *data* for each tracked sensor.

        :param data: mapping of sensor name -> dict of reported fields.
            Fields missing from an update keep their previous value.
        """
        for sensor in self.temperature_store:
            if sensor not in data:
                continue
            update = data[sensor]
            last_val = self.last_temps[sensor]
            self.last_temps[sensor] = (
                round(update.get('temperature', last_val[0]), 2),
                update.get('target', last_val[1]),
                update.get('power', last_val[2]),
                update.get('speed', last_val[3]))

    def _update_temperature_store(self):
        """Append the latest reading of every sensor to its history deques."""
        # XXX - If klippy is not connected, set values to zero
        # as they are unknown?
        for sensor, vals in self.last_temps.items():
            history = self.temperature_store[sensor]
            history['temperatures'].append(vals[0])
            for val, item in zip(vals[1:], ["targets", "powers", "speeds"]):
                # Optional fields exist only for sensors that report them.
                if item in history:
                    history[item].append(val)

    async def _handle_temp_store_request(self, web_request):
        """Return the temperature store with every deque copied to a list."""
        return {
            name: {k: list(v) for k, v in sensor.items()}
            for name, sensor in self.temperature_store.items()
        }

    async def close(self):
        """Stop the periodic temperature sampler."""
        self.temp_update_cb.stop()

    def _update_gcode_store(self, response):
        """Append a gcode *response* to the gcode queue with a timestamp."""
        curtime = time.time()
        self.gcode_queue.append(
            {'message': response, 'time': curtime, 'type': "response"})

    def store_gcode_command(self, script):
        """Append every non-blank line of *script* to the gcode queue.

        Each line is stripped and stored as its own ``"command"`` entry; all
        entries of one call share the same timestamp.

        :param script: one or more gcode commands separated by newlines.
        """
        curtime = time.time()
        for cmd in script.split('\n'):
            cmd = cmd.strip()
            if not cmd:
                continue
            # Store the individual command, not the whole script: appending
            # ``script`` here duplicated the full text once per line.
            self.gcode_queue.append(
                {'message': cmd, 'time': curtime, 'type': "command"})

    async def _handle_gcode_store_request(self, web_request):
        """Return the stored gcode entries, optionally only the newest ones.

        :param web_request: request object; its optional integer ``count``
            argument limits the result to the ``count`` most recent entries.
        :returns: ``{'gcode_store': [entry, ...]}`` ordered oldest first.
        """
        count = web_request.get_int("count", None)
        gc_responses = list(self.gcode_queue)
        if count is not None:
            # ``seq[-0:]`` is the whole sequence and a negative count would
            # slice from the front, so non-positive counts yield nothing.
            gc_responses = gc_responses[-count:] if count > 0 else []
        return {'gcode_store': gc_responses}
class Engine(BaseEngine):
    """Engine implementation backed by Redis.

    Three separate Redis clients are used: ``subscriber`` for channel
    subscriptions, ``publisher`` for publishing and ``worker`` for the
    presence and history commands.  A periodic callback checks the
    connections once per second and reconnects when one has dropped or
    when a failure has set ``_need_reconnect``.
    """

    NAME = 'Redis'

    OK_RESPONSE = b'OK'

    # Port used when ``redis_url`` does not carry an explicit one.
    DEFAULT_PORT = 6379

    def __init__(self, *args, **kwargs):
        """Read connection settings from options and create the clients."""
        super(Engine, self).__init__(*args, **kwargs)

        if not self.options.redis_url:
            self.host = self.options.redis_host
            self.port = self.options.redis_port
            self.password = self.options.redis_password
            self.db = self.options.redis_db
        else:
            # according to https://devcenter.heroku.com/articles/redistogo
            parsed_url = urlparse.urlparse(self.options.redis_url)
            self.host = parsed_url.hostname
            # ``port`` is None when the URL omits it and int(None) raises
            # TypeError, so fall back to the default port.
            self.port = int(parsed_url.port or self.DEFAULT_PORT)
            self.db = 0
            self.password = parsed_url.password

        self.connection_check = PeriodicCallback(self.check_connection, 1000)
        self._need_reconnect = False

        self.subscriber = toredis.Client(io_loop=self.io_loop)
        self.publisher = toredis.Client(io_loop=self.io_loop)
        self.worker = toredis.Client(io_loop=self.io_loop)

        # subscription key -> {client uid: client}
        self.subscriptions = {}

    def initialize(self):
        """Open the Redis connections and log the target server."""
        self.connect()
        logger.info("Redis engine at {0}:{1} (db {2})".format(
            self.host, self.port, self.db))

    def on_auth(self, res):
        """Log an error when the AUTH reply is not OK."""
        if res != self.OK_RESPONSE:
            logger.error("auth failed: {0}".format(res))

    def on_subscriber_select(self, res):
        """
        After selecting subscriber database subscribe on channels
        """
        if res != self.OK_RESPONSE:
            # error returned
            logger.error("select database failed: {0}".format(res))
            self._need_reconnect = True
            return

        self.subscriber.subscribe(self.admin_channel_name,
                                  callback=self.on_redis_message)
        self.subscriber.subscribe(self.control_channel_name,
                                  callback=self.on_redis_message)

        # Iterate over a copy and re-check membership so that keys removed
        # while re-subscribing are skipped.
        for subscription in self.subscriptions.copy():
            if subscription not in self.subscriptions:
                continue
            self.subscriber.subscribe(subscription,
                                      callback=self.on_redis_message)

    def on_select(self, res):
        """Request a reconnect when the SELECT reply is not OK."""
        if res != self.OK_RESPONSE:
            logger.error("select database failed: {0}".format(res))
            self._need_reconnect = True

    def connect(self):
        """
        Connect from scratch
        """
        try:
            self.subscriber.connect(host=self.host, port=self.port)
            self.publisher.connect(host=self.host, port=self.port)
            self.worker.connect(host=self.host, port=self.port)
        except Exception as e:
            logger.error("error connecting to Redis server: %s" % (str(e)))
        else:
            if self.password:
                self.subscriber.auth(self.password, callback=self.on_auth)
                self.publisher.auth(self.password, callback=self.on_auth)
                self.worker.auth(self.password, callback=self.on_auth)

            self.subscriber.select(self.db,
                                   callback=self.on_subscriber_select)
            self.publisher.select(self.db, callback=self.on_select)
            self.worker.select(self.db, callback=self.on_select)

        # Restart the periodic check whether or not connecting succeeded,
        # so a failed attempt is retried on the next tick.
        self.connection_check.stop()
        self.connection_check.start()

    def check_connection(self):
        """Reconnect when any client is disconnected or a reconnect is due."""
        conn_statuses = [
            self.subscriber.is_connected(),
            self.publisher.is_connected(),
            self.worker.is_connected()
        ]
        connection_dropped = not all(conn_statuses)

        if connection_dropped or self._need_reconnect:
            logger.info('reconnecting to Redis')
            self._need_reconnect = False
            self.connect()

    def _publish(self, channel, message):
        """Publish *message* into *channel*.

        :returns: ``True`` on success; ``False`` when the stream is closed,
            in which case a reconnect is requested.
        """
        try:
            self.publisher.publish(channel, message)
        except StreamClosedError as e:
            self._need_reconnect = True
            logger.error(e)
            return False
        else:
            return True

    @coroutine
    def publish_message(self, channel, body,
                        method=BaseEngine.DEFAULT_PUBLISH_METHOD):
        """
        Publish message into channel of stream.
        """
        response = Response()
        method = method or self.DEFAULT_PUBLISH_METHOD
        response.method = method
        response.body = body
        to_publish = response.as_message()
        result = self._publish(channel, to_publish)
        raise Return((result, None))

    @coroutine
    def publish_control_message(self, message):
        """Publish JSON-encoded *message* into the control channel."""
        result = self._publish(self.control_channel_name,
                               json_encode(message))
        raise Return((result, None))

    @coroutine
    def publish_admin_message(self, message):
        """Publish JSON-encoded *message* into the admin channel."""
        result = self._publish(self.admin_channel_name,
                               json_encode(message))
        raise Return((result, None))

    @coroutine
    def on_redis_message(self, redis_message):
        """
        Got message from Redis, dispatch it into right message handler.
        """
        msg_type = redis_message[0]
        if six.PY3:
            msg_type = msg_type.decode()

        # Only 'message' replies carry data to dispatch.
        if msg_type != 'message':
            return

        channel = redis_message[1]
        if six.PY3:
            channel = channel.decode()

        if channel == self.control_channel_name:
            yield self.handle_control_message(json_decode(redis_message[2]))
        elif channel == self.admin_channel_name:
            yield self.handle_admin_message(json_decode(redis_message[2]))
        else:
            yield self.handle_message(channel, redis_message[2])

    @coroutine
    def handle_admin_message(self, message):
        """Send JSON-encoded *message* to every admin connection."""
        message = json_encode(message)
        # Iterate over a snapshot: mutating the live dict while iterating
        # raises RuntimeError, and the membership re-check below is only
        # meaningful against a snapshot.
        admin_connections = list(
            six.iteritems(self.application.admin_connections))
        for uid, connection in admin_connections:
            if uid not in self.application.admin_connections:
                continue
            connection.send(message)

        raise Return((True, None))

    @coroutine
    def handle_control_message(self, message):
        """
        Handle control message.
        """
        app_id = message.get("app_id")
        method = message.get("method")
        params = message.get("params")

        if app_id and app_id == self.application.uid:
            # app_id is set when the same work must not be done twice for
            # one application: a control message whose app_id matches this
            # application's uid is not processed by this instance.
            raise Return((True, None))

        func = getattr(self.application, 'handle_%s' % method, None)
        if not func:
            raise Return((None, self.application.METHOD_NOT_FOUND))

        result, error = yield func(params)
        raise Return((result, error))

    @coroutine
    def handle_message(self, channel, message_data):
        """Forward *message_data* to every client subscribed on *channel*.

        :returns: ``(True, None)``.
        """
        if channel not in self.subscriptions:
            raise Return((True, None))

        # ``yield client.send`` hands control back to the IO loop, during
        # which subscriptions may change.  Iterate over a snapshot to avoid
        # "dictionary changed size during iteration", then re-check that
        # each client is still subscribed before sending.
        subscribers = list(six.iteritems(self.subscriptions[channel]))
        for uid, client in subscribers:
            still_subscribed = (
                channel in self.subscriptions
                and uid in self.subscriptions[channel]
            )
            if still_subscribed:
                yield client.send(message_data)

        # Same result as the early-exit path above.
        raise Return((True, None))

    def subscribe_key(self, subscription_key):
        """Subscribe the subscriber client on *subscription_key*."""
        self.subscriber.subscribe(subscription_key,
                                  callback=self.on_redis_message)

    def unsubscribe_key(self, subscription_key):
        """Unsubscribe the subscriber client from *subscription_key*."""
        self.subscriber.unsubscribe(subscription_key)

    @coroutine
    def add_subscription(self, project_id, channel, client):
        """
        Subscribe application on channel if necessary and register client
        to receive messages from that channel.
        """
        subscription_key = self.get_subscription_key(project_id, channel)

        self.subscribe_key(subscription_key)

        if subscription_key not in self.subscriptions:
            self.subscriptions[subscription_key] = {}

        self.subscriptions[subscription_key][client.uid] = client

        raise Return((True, None))

    @coroutine
    def remove_subscription(self, project_id, channel, client):
        """
        Unsubscribe application from channel if necessary and prevent client
        from receiving messages from that channel.
        """
        subscription_key = self.get_subscription_key(project_id, channel)

        try:
            del self.subscriptions[subscription_key][client.uid]
        except KeyError:
            pass

        try:
            # Drop the Redis subscription once the last client is gone.
            if not self.subscriptions[subscription_key]:
                self.unsubscribe_key(subscription_key)
                del self.subscriptions[subscription_key]
        except KeyError:
            pass

        raise Return((True, None))

    def get_presence_hash_key(self, project_id, channel):
        """Return the Redis key of the presence hash for the channel."""
        return "%s:presence:hash:%s:%s" % (self.prefix, project_id, channel)

    def get_presence_set_key(self, project_id, channel):
        """Return the Redis key of the presence set for the channel."""
        return "%s:presence:set:%s:%s" % (self.prefix, project_id, channel)

    def get_history_list_key(self, project_id, channel):
        """Return the Redis key of the history list for the channel."""
        return "%s:history:list:%s:%s" % (self.prefix, project_id, channel)

    @coroutine
    def add_presence(self, project_id, channel, uid, user_info,
                     presence_timeout=None):
        """Store presence information of *uid* for the channel.

        The uid is added to the presence set with its expiration time and
        the JSON-encoded *user_info* is written to the presence hash.

        :returns: ``(True, None)`` on success, ``(None, error)`` when the
            stream is closed.
        """
        now = int(time.time())
        expire_at = now + (presence_timeout or self.presence_timeout)
        hash_key = self.get_presence_hash_key(project_id, channel)
        set_key = self.get_presence_set_key(project_id, channel)
        try:
            pipeline = self.worker.pipeline()
            pipeline.multi()
            pipeline.zadd(set_key, {uid: expire_at})
            pipeline.hset(hash_key, uid, json_encode(user_info))
            pipeline.execute()
            yield Task(pipeline.send)
        except StreamClosedError as e:
            raise Return((None, e))
        else:
            raise Return((True, None))

    @coroutine
    def remove_presence(self, project_id, channel, uid):
        """Remove *uid* from the presence hash and set of the channel.

        :returns: ``(True, None)`` on success, ``(None, error)`` when the
            stream is closed.
        """
        hash_key = self.get_presence_hash_key(project_id, channel)
        set_key = self.get_presence_set_key(project_id, channel)
        try:
            pipeline = self.worker.pipeline()
            pipeline.hdel(hash_key, uid)
            pipeline.zrem(set_key, uid)
            yield Task(pipeline.send)
        except StreamClosedError as e:
            raise Return((None, e))
        else:
            raise Return((True, None))

    @coroutine
    def get_presence(self, project_id, channel):
        """Return presence information of the channel.

        Entries of the presence set scored between 0 and the current time
        are removed from both the set and the hash before the hash is read.

        :returns: ``(presence dict, None)`` on success, ``(None, error)``
            when the stream is closed.
        """
        now = int(time.time())
        hash_key = self.get_presence_hash_key(project_id, channel)
        set_key = self.get_presence_set_key(project_id, channel)
        try:
            expired_keys = yield Task(self.worker.zrangebyscore,
                                      set_key, 0, now)
            if expired_keys:
                pipeline = self.worker.pipeline()
                pipeline.zremrangebyscore(set_key, 0, now)
                pipeline.hdel(hash_key, [x.decode() for x in expired_keys])
                yield Task(pipeline.send)
            data = yield Task(self.worker.hgetall, hash_key)
        except StreamClosedError as e:
            raise Return((None, e))
        else:
            raise Return((dict_from_list(data), None))

    @coroutine
    def add_history_message(self, project_id, channel, message,
                            history_size=None, history_expire=0):
        """Push *message* onto the history list of the channel.

        The list is trimmed to *history_size* entries (``self.history_size``
        when not given).  A truthy *history_expire* sets an expiration on
        the list; otherwise any expiration is removed.

        :returns: ``(True, None)`` on success, ``(None, error)`` when the
            stream is closed.
        """
        history_size = history_size or self.history_size
        history_list_key = self.get_history_list_key(project_id, channel)
        try:
            pipeline = self.worker.pipeline()
            pipeline.lpush(history_list_key, json_encode(message))
            pipeline.ltrim(history_list_key, 0, history_size - 1)
            if history_expire:
                pipeline.expire(history_list_key, history_expire)
            else:
                pipeline.persist(history_list_key)
            yield Task(pipeline.send)
        except StreamClosedError as e:
            raise Return((None, e))
        else:
            raise Return((True, None))

    @coroutine
    def get_history(self, project_id, channel):
        """Return the decoded history messages of the channel.

        :returns: ``(list of messages, None)`` on success, ``(None, error)``
            when the stream is closed.
        """
        history_list_key = self.get_history_list_key(project_id, channel)
        try:
            data = yield Task(self.worker.lrange, history_list_key, 0, -1)
        except StreamClosedError as e:
            raise Return((None, e))
        else:
            raise Return(([json_decode(x.decode()) for x in data], None))
class Reader(Client): r""" Reader provides high-level functionality for building robust NSQ consumers in Python on top of the async module. Reader receives messages over the specified ``topic/channel`` and calls ``message_handler`` for each message (up to ``max_tries``). Multiple readers can be instantiated in a single process (to consume from multiple topics/channels at once). Supports various hooks to modify behavior when heartbeats are received, to temporarily disable the reader, and pre-process/validate messages. When supplied a list of ``nsqlookupd`` addresses, it will periodically poll those addresses to discover new producers of the specified ``topic``. It maintains a sufficient RDY count based on the # of producers and your configured ``max_in_flight``. Handlers should be defined as shown in the examples below. The ``message_handler`` callback function receives a :class:`nsq.Message` object that has instance methods :meth:`nsq.Message.finish`, :meth:`nsq.Message.requeue`, and :meth:`nsq.Message.touch` which can be used to respond to ``nsqd``. As an alternative to explicitly calling these response methods, the handler function can simply return ``True`` to finish the message, or ``False`` to requeue it. If the handler function calls :meth:`nsq.Message.enable_async`, then automatic finish/requeue is disabled, allowing the :class:`nsq.Message` to finish or requeue in a later async callback or context. The handler function may also be a coroutine, in which case Message async handling is enabled automatically, but the coroutine can still return a final value of True/False to automatically finish/requeue the message. After re-queueing a message, the handler will backoff from processing additional messages for an increasing delay (calculated exponentially based on consecutive failures up to ``max_backoff_duration``). 
Synchronous example:: import nsq def handler(message): print message return True r = nsq.Reader(message_handler=handler, lookupd_http_addresses=['http://127.0.0.1:4161'], topic='nsq_reader', channel='asdf', lookupd_poll_interval=15) nsq.run() Asynchronous example:: import nsq buf = [] def process_message(message): global buf message.enable_async() # cache the message for later processing buf.append(message) if len(buf) >= 3: for msg in buf: print msg msg.finish() buf = [] else: print 'deferring processing' r = nsq.Reader(message_handler=process_message, lookupd_http_addresses=['http://127.0.0.1:4161'], topic='nsq_reader', channel='async', max_in_flight=9) nsq.run() :param message_handler: the callable that will be executed for each message received :param topic: specifies the desired NSQ topic :param channel: specifies the desired NSQ channel :param name: a string that is used for logging messages (defaults to 'topic:channel') :param nsqd_tcp_addresses: a sequence of string addresses of the nsqd instances this reader should connect to :param lookupd_http_addresses: a sequence of string addresses of the nsqlookupd instances this reader should query for producers of the specified topic :param max_tries: the maximum number of attempts the reader will make to process a message after which messages will be automatically discarded :param max_in_flight: the maximum number of messages this reader will pipeline for processing. this value will be divided evenly amongst the configured/discovered nsqd producers :param lookupd_poll_interval: the amount of time in seconds between querying all of the supplied nsqlookupd instances. a random amount of time based on this value will be initially introduced in order to add jitter when multiple readers are running :param lookupd_poll_jitter: The maximum fractional amount of jitter to add to the lookupd poll loop. This helps evenly distribute requests even if multiple consumers restart at the same time. 
:param lookupd_connect_timeout: the amount of time in seconds to wait for a connection to ``nsqlookupd`` to be established :param lookupd_request_timeout: the amount of time in seconds to wait for a request to ``nsqlookupd`` to complete. :param low_rdy_idle_timeout: the amount of time in seconds to wait for a message from a producer when in a state where RDY counts are re-distributed (ie. max_in_flight < num_producers) :param max_backoff_duration: the maximum time we will allow a backoff state to last in seconds :param \*\*kwargs: passed to :class:`nsq.AsyncConn` initialization """ def __init__( self, topic, channel, message_handler=None, name=None, nsqd_tcp_addresses=None, lookupd_http_addresses=None, max_tries=5, max_in_flight=1, lookupd_poll_interval=60, low_rdy_idle_timeout=10, max_backoff_duration=128, lookupd_poll_jitter=0.3, lookupd_connect_timeout=1, lookupd_request_timeout=2, **kwargs): super(Reader, self).__init__(**kwargs) assert isinstance(topic, string_types) and len(topic) > 0 assert isinstance(channel, string_types) and len(channel) > 0 assert isinstance(max_in_flight, int) and max_in_flight > 0 assert isinstance(max_backoff_duration, (int, float)) and max_backoff_duration > 0 assert isinstance(name, string_types + (None.__class__,)) assert isinstance(lookupd_poll_interval, int) assert isinstance(lookupd_poll_jitter, float) assert isinstance(lookupd_connect_timeout, int) assert isinstance(lookupd_request_timeout, int) assert lookupd_poll_jitter >= 0 and lookupd_poll_jitter <= 1 if nsqd_tcp_addresses: if not isinstance(nsqd_tcp_addresses, (list, set, tuple)): assert isinstance(nsqd_tcp_addresses, string_types) nsqd_tcp_addresses = [nsqd_tcp_addresses] else: nsqd_tcp_addresses = [] if lookupd_http_addresses: if not isinstance(lookupd_http_addresses, (list, set, tuple)): assert isinstance(lookupd_http_addresses, string_types) lookupd_http_addresses = [lookupd_http_addresses] random.shuffle(lookupd_http_addresses) else: lookupd_http_addresses = [] assert 
nsqd_tcp_addresses or lookupd_http_addresses self.name = name or (topic + ':' + channel) self.message_handler = None if message_handler: self.set_message_handler(message_handler) self.topic = topic self.channel = channel self.nsqd_tcp_addresses = nsqd_tcp_addresses self.lookupd_http_addresses = lookupd_http_addresses self.lookupd_query_index = 0 self.max_tries = max_tries self.max_in_flight = max_in_flight self.low_rdy_idle_timeout = low_rdy_idle_timeout self.total_rdy = 0 self.need_rdy_redistributed = False self.lookupd_poll_interval = lookupd_poll_interval self.lookupd_poll_jitter = lookupd_poll_jitter self.lookupd_connect_timeout = lookupd_connect_timeout self.lookupd_request_timeout = lookupd_request_timeout self.random_rdy_ts = time.time() # Verify keyword arguments valid_args = func_args(AsyncConn.__init__) diff = set(kwargs) - set(valid_args) assert len(diff) == 0, 'Invalid keyword argument(s): %s' % list(diff) self.conn_kwargs = kwargs self.backoff_timer = BackoffTimer(0, max_backoff_duration) self.backoff_block = False self.backoff_block_completed = True self.conns = {} self.connection_attempts = {} self.http_client = tornado.httpclient.AsyncHTTPClient() # will execute when run() is called (for all Reader instances) self.io_loop.add_callback(self._run) self.redist_periodic = None self.query_periodic = None def _run(self): assert self.message_handler, "you must specify the Reader's message_handler" logger.info('[%s] starting reader for %s/%s...', self.name, self.topic, self.channel) for addr in self.nsqd_tcp_addresses: address, port = addr.split(':') self.connect_to_nsqd(address, int(port)) self.redist_periodic = PeriodicCallback( self._redistribute_rdy_state, 5 * 1000, ) self.redist_periodic.start() if not self.lookupd_http_addresses: return # trigger the first lookup query manually self.io_loop.spawn_callback(self.query_lookupd) self.query_periodic = PeriodicCallback( self.query_lookupd, self.lookupd_poll_interval * 1000, ) # randomize the time we start 
this poll loop so that all # consumers don't query at exactly the same time delay = random.random() * self.lookupd_poll_interval * self.lookupd_poll_jitter self.io_loop.call_later(delay, self.query_periodic.start) def close(self): """ Closes all connections stops all periodic callbacks """ for conn in self.conns.values(): conn.close() self.redist_periodic.stop() if self.query_periodic is not None: self.query_periodic.stop() def set_message_handler(self, message_handler): """ Assigns the callback method to be executed for each message received :param message_handler: a callable that takes a single argument """ assert callable(message_handler), 'message_handler must be callable' self.message_handler = message_handler def _connection_max_in_flight(self): return max(1, self.max_in_flight // max(1, len(self.conns))) def is_starved(self): """ Used to identify when buffered messages should be processed and responded to. When max_in_flight > 1 and you're batching messages together to perform work is isn't possible to just compare the len of your list of buffered messages against your configured max_in_flight (because max_in_flight may not be evenly divisible by the number of producers you're connected to, ie. you might never get that many messages... it's a *max*). Example:: def message_handler(self, nsq_msg, reader): # buffer messages if reader.is_starved(): # perform work reader = nsq.Reader(...) 
reader.set_message_handler(functools.partial(message_handler, reader=reader)) nsq.run() """ for conn in itervalues(self.conns): if conn.in_flight > 0 and conn.in_flight >= (conn.last_rdy * 0.85): return True return False def _on_message(self, conn, message, **kwargs): try: self._handle_message(conn, message) except Exception: logger.exception('[%s:%s] failed to handle_message() %r', conn.id, self.name, message) def _handle_message(self, conn, message): self._maybe_update_rdy(conn) result = False try: if 0 < self.max_tries < message.attempts: self.giving_up(message) return message.finish() pre_processed_message = self.preprocess_message(message) if not self.validate_message(pre_processed_message): return message.finish() result = self.process_message(message) except Exception: logger.exception('[%s:%s] uncaught exception while handling message %s body:%r', conn.id, self.name, message.id, message.body) if not message.has_responded(): return message.requeue() if result not in (True, False, None): # assume handler returned a Future or Coroutine message.enable_async() fut = tornado.gen.convert_yielded(result) fut.add_done_callback(functools.partial(self._maybe_finish, message)) elif not message.is_async() and not message.has_responded(): assert result is not None, 'ambiguous return value for synchronous mode' if result: return message.finish() return message.requeue() def _maybe_finish(self, message, fut): if not message.has_responded(): try: if fut.result(): message.finish() return except Exception: pass message.requeue() def _maybe_update_rdy(self, conn): if self.backoff_timer.get_interval() or self.max_in_flight == 0: return # Update RDY in 2 cases: # 1. On a new connection or in backoff we start with a tentative RDY # count of 1. After successfully receiving a first message we go to # full throttle. # 2. After a change in connection count or max_in_flight we adjust to the new # connection_max_in_flight. 
conn_max_in_flight = self._connection_max_in_flight() if conn.rdy == 1 or conn.rdy != conn_max_in_flight: self._send_rdy(conn, conn_max_in_flight) def _finish_backoff_block(self): self.backoff_block = False # we must have raced and received a message out of order that resumed # so just complete the backoff block if not self.backoff_timer.get_interval(): self._complete_backoff_block() return # test the waters after finishing a backoff round # if we have no connections, this will happen when a new connection gets RDY 1 if not self.conns or self.max_in_flight == 0: return conn = random.choice(list(self.conns.values())) logger.info('[%s:%s] testing backoff state with RDY 1', conn.id, self.name) self._send_rdy(conn, 1) # for tests return conn def _on_backoff_resume(self, success, **kwargs): if success: self.backoff_timer.success() elif success is False and not self.backoff_block: self.backoff_timer.failure() self._enter_continue_or_exit_backoff() def _complete_backoff_block(self): self.backoff_block_completed = True rdy = self._connection_max_in_flight() logger.info('[%s] backoff complete, resuming normal operation (%d connections)', self.name, len(self.conns)) for c in self.conns.values(): self._send_rdy(c, rdy) def _enter_continue_or_exit_backoff(self): # Take care of backoff in the appropriate cases. When this # happens, we set a failure on the backoff timer and set the RDY count to zero. # Once the backoff time has expired, we allow *one* of the connections let # a single message through to test the water. This will continue until we # reach no backoff in which case we go back to the normal RDY count. 
current_backoff_interval = self.backoff_timer.get_interval() # do nothing if self.backoff_block: return # we're out of backoff completely, return to full blast for all conns if not self.backoff_block_completed and not current_backoff_interval: self._complete_backoff_block() return # enter or continue a backoff iteration if current_backoff_interval: self._start_backoff_block() def _start_backoff_block(self): self.backoff_block = True self.backoff_block_completed = False backoff_interval = self.backoff_timer.get_interval() logger.info('[%s] backing off for %0.2f seconds (%d connections)', self.name, backoff_interval, len(self.conns)) for c in self.conns.values(): self._send_rdy(c, 0) self.io_loop.call_later(backoff_interval, self._finish_backoff_block) def _rdy_retry(self, conn, value): conn.rdy_timeout = None self._send_rdy(conn, value) def _send_rdy(self, conn, value): if conn.rdy_timeout: self.io_loop.remove_timeout(conn.rdy_timeout) conn.rdy_timeout = None if value and (self.disabled() or self.max_in_flight == 0): logger.info('[%s:%s] disabled, delaying RDY state change', conn.id, self.name) rdy_retry_callback = functools.partial(self._rdy_retry, conn, value) conn.rdy_timeout = self.io_loop.call_later(15, rdy_retry_callback) return if value > conn.max_rdy_count: value = conn.max_rdy_count new_rdy = max(self.total_rdy - conn.rdy + value, 0) if conn.send_rdy(value): self.total_rdy = new_rdy def connect_to_nsqd(self, host, port): """ Adds a connection to ``nsqd`` at the specified address. 
:param host: the address to connect to :param port: the port to connect to """ assert isinstance(host, string_types) assert isinstance(port, int) conn = AsyncConn(host, port, **self.conn_kwargs) conn.on('identify', self._on_connection_identify) conn.on('identify_response', self._on_connection_identify_response) conn.on('auth', self._on_connection_auth) conn.on('auth_response', self._on_connection_auth_response) conn.on('error', self._on_connection_error) conn.on('close', self._on_connection_close) conn.on('ready', self._on_connection_ready) conn.on('message', self._on_message) conn.on('heartbeat', self._on_heartbeat) conn.on('backoff', functools.partial(self._on_backoff_resume, success=False)) conn.on('resume', functools.partial(self._on_backoff_resume, success=True)) conn.on('continue', functools.partial(self._on_backoff_resume, success=None)) if conn.id in self.conns: return # only attempt to re-connect once every 10s per destination # this throttles reconnects to failed endpoints now = time.time() last_connect_attempt = self.connection_attempts.get(conn.id) if last_connect_attempt and last_connect_attempt > now - 10: return self.connection_attempts[conn.id] = now logger.info('[%s:%s] connecting to nsqd', conn.id, self.name) conn.connect() return conn def _on_connection_ready(self, conn, **kwargs): conn.send(protocol.subscribe(self.topic, self.channel)) # re-check to make sure another connection didn't beat this one done if conn.id in self.conns: logger.warning( '[%s:%s] connected to NSQ but anothermatching connection already exists', conn.id, self.name) conn.close() return if conn.max_rdy_count < self.max_in_flight: logger.warning( '[%s:%s] max RDY count %d < reader max in flight %d, truncation possible', conn.id, self.name, conn.max_rdy_count, self.max_in_flight) self.conns[conn.id] = conn conn_max_in_flight = self._connection_max_in_flight() for c in self.conns.values(): if c.rdy > conn_max_in_flight: self._send_rdy(c, conn_max_in_flight) # we send an initial 
RDY of 1 up to our configured max_in_flight # this resolves two cases: # 1. `max_in_flight >= num_conns` ensuring that no connections are ever # *initially* starved since redistribute won't apply # 2. `max_in_flight < num_conns` ensuring that we never exceed max_in_flight # and rely on the fact that redistribute will handle balancing RDY across conns if not self.backoff_timer.get_interval() or len(self.conns) == 1: # only send RDY 1 if we're not in backoff (some other conn # should be testing the waters) # (but always send it if we're the first) self._send_rdy(conn, 1) def _on_connection_close(self, conn, **kwargs): if conn.id in self.conns: del self.conns[conn.id] self.total_rdy = max(self.total_rdy - conn.rdy, 0) logger.warning('[%s:%s] connection closed', conn.id, self.name) if (conn.rdy_timeout or conn.rdy) and \ (len(self.conns) == self.max_in_flight or self.backoff_timer.get_interval()): # we're toggling out of (normal) redistribution cases and this conn # had a RDY count... # # trigger RDY redistribution to make sure this RDY is moved # to a new connection self.need_rdy_redistributed = True if conn.rdy_timeout: self.io_loop.remove_timeout(conn.rdy_timeout) conn.rdy_timeout = None if not self.lookupd_http_addresses: # automatically reconnect to nsqd addresses when not using lookupd logger.info('[%s:%s] attempting to reconnect in 15s', conn.id, self.name) reconnect_callback = functools.partial(self.connect_to_nsqd, host=conn.host, port=conn.port) self.io_loop.call_later(15, reconnect_callback) @tornado.gen.coroutine def query_lookupd(self): """ Trigger a query of the configured ``nsq_lookupd_http_addresses``. 
""" endpoint = self.lookupd_http_addresses[self.lookupd_query_index] self.lookupd_query_index = (self.lookupd_query_index + 1) % len(self.lookupd_http_addresses) # urlsplit() is faulty if scheme not present if '://' not in endpoint: endpoint = 'http://' + endpoint scheme, netloc, path, query, fragment = urlparse.urlsplit(endpoint) if not path or path == "/": path = "/lookup" params = parse_qs(query) params['topic'] = self.topic query = urlencode(_utf8_params(params), doseq=1) lookupd_url = urlparse.urlunsplit((scheme, netloc, path, query, fragment)) req = tornado.httpclient.HTTPRequest( lookupd_url, method='GET', headers={'Accept': 'application/vnd.nsq; version=1.0'}, connect_timeout=self.lookupd_connect_timeout, request_timeout=self.lookupd_request_timeout) try: response = yield self.http_client.fetch(req) except Exception as e: logger.warning('[%s] lookupd %s query error: %s', self.name, lookupd_url, e) return try: lookup_data = json.loads(response.body.decode("utf8")) except ValueError: logger.warning('[%s] lookupd %s failed to parse JSON: %r', self.name, lookupd_url, response.body) return for producer in lookup_data['producers']: # TODO: this can be dropped for 1.0 address = producer.get('broadcast_address', producer.get('address')) assert address self.connect_to_nsqd(address, producer['tcp_port']) def set_max_in_flight(self, max_in_flight): """Dynamically adjust the reader max_in_flight. Set to 0 to immediately disable a Reader""" assert isinstance(max_in_flight, int) self.max_in_flight = max_in_flight if max_in_flight == 0: # set RDY 0 to all connections for conn in itervalues(self.conns): if conn.rdy > 0: logger.debug('[%s:%s] rdy: %d -> 0', conn.id, self.name, conn.rdy) self._send_rdy(conn, 0) self.total_rdy = 0 else: self.need_rdy_redistributed = True self._redistribute_rdy_state() def _redistribute_rdy_state(self): # We redistribute RDY counts in a few cases: # # 1. our # of connections exceeds our configured max_in_flight # 2. 
we're in backoff mode (but not in a current backoff block) # 3. something out-of-band has set the need_rdy_redistributed flag (connection closed # that was about to get RDY during backoff) # # At a high level, we're trying to mitigate stalls related to low-volume # producers when we're unable (by configuration or backoff) to provide a RDY count # of (at least) 1 to all of our connections. if not self.conns: return if self.disabled() or self.backoff_block or self.max_in_flight == 0: return if len(self.conns) > self.max_in_flight: self.need_rdy_redistributed = True logger.debug('redistributing RDY state (%d conns > %d max_in_flight)', len(self.conns), self.max_in_flight) backoff_interval = self.backoff_timer.get_interval() if backoff_interval and len(self.conns) > 1: self.need_rdy_redistributed = True logger.debug('redistributing RDY state (%d backoff interval and %d conns > 1)', backoff_interval, len(self.conns)) if self.need_rdy_redistributed: self.need_rdy_redistributed = False # first set RDY 0 to all connections that have not received a message within # a configurable timeframe (low_rdy_idle_timeout). for conn_id, conn in iteritems(self.conns): last_message_duration = time.time() - conn.last_msg_timestamp logger.debug('[%s:%s] rdy: %d (last message received %.02fs)', conn.id, self.name, conn.rdy, last_message_duration) if conn.rdy > 0 and last_message_duration > self.low_rdy_idle_timeout: logger.info('[%s:%s] idle connection, giving up RDY count', conn.id, self.name) self._send_rdy(conn, 0) conns = self.conns.values() in_flight_or_rdy = len([c for c in conns if c.in_flight or c.rdy]) if backoff_interval: available_rdy = max(0, 1 - in_flight_or_rdy) else: available_rdy = max(0, self.max_in_flight - in_flight_or_rdy) # if moving any connections from RDY 0 to non-0 would violate in-flight constraints, # set RDY 0 on some connection with msgs in flight so that a later redistribution # round can proceed and we don't stay pinned to the same connections. 
# # if nothing's in flight, then we have connections with RDY 1 that are still # waiting to hit the idle timeout, in which case it's ok to do nothing. in_flight = [c for c in conns if c.in_flight] if in_flight and not available_rdy: conn = random.choice(in_flight) logger.info('[%s:%s] too many msgs in flight, giving up RDY count', conn.id, self.name) self._send_rdy(conn, 0) # randomly walk the list of possible connections and send RDY 1 (up to our # calculated "max_in_flight"). We only need to send RDY 1 because in both # cases described above your per connection RDY count would never be higher. # # We also don't attempt to avoid the connections who previously might have had RDY 1 # because it would be overly complicated and not actually worth it (ie. given enough # redistribution rounds it doesn't matter). possible_conns = [c for c in conns if not (c.in_flight or c.rdy)] while possible_conns and available_rdy: available_rdy -= 1 conn = possible_conns.pop(random.randrange(len(possible_conns))) logger.info('[%s:%s] redistributing RDY', conn.id, self.name) self._send_rdy(conn, 1) # for tests return conn # # subclass overwriteable # def process_message(self, message): """ Called when a message is received in order to execute the configured ``message_handler`` This is useful to subclass and override if you want to change how your message handlers are called. :param message: the :class:`nsq.Message` received """ return self.message_handler(message) def giving_up(self, message): """ Called when a message has been received where ``msg.attempts > max_tries`` This is useful to subclass and override to perform a task (such as writing to disk, etc.) 
:param message: the :class:`nsq.Message` received """ logger.warning('[%s] giving up on message %s after %d tries (max:%d) %r', self.name, message.id, message.attempts, self.max_tries, message.body) def _on_connection_identify_response(self, conn, data, **kwargs): if not hasattr(self, '_disabled_notice'): self._disabled_notice = True def semver(v): def cast(x): try: return int(x) except Exception: return x return [cast(x) for x in v.replace('-', '.').split('.')] if self.disabled.__code__ != Reader.disabled.__code__ and \ semver(data['version']) >= semver('0.3'): warnings.warn('disabled() is deprecated and will be removed in a future release, ' 'use set_max_in_flight(0) instead', DeprecationWarning) return super(Reader, self)._on_connection_identify_response(conn, data, **kwargs) @classmethod def disabled(cls): """ Called as part of RDY handling to identify whether this Reader has been disabled This is useful to subclass and override to examine a file on disk or a key in cache to identify if this reader should pause execution (during a deploy, etc.). Note: deprecated. Use set_max_in_flight(0) """ return False def validate_message(self, message): return True def preprocess_message(self, message): return message
class ApsCapture(object):
    """Run one APSUSE capture pipeline.

    A pipeline consists of a DADA ring buffer (created with ``dada_db``),
    an MKRECV process that writes into the buffer and an ``apsuse`` process
    that reads from it.  The instance also owns the sensors describing the
    pipeline and talks to ``apsuse`` over a control socket to start and
    stop file writing.

    :param capture_interface: Value written to the ``interface`` field of
        the MKRECV header.
    :param control_socket: Address handed to ``apsuse --socket`` and to
        ``UDSClient`` when sending commands.
    :param mkrecv_config_filename: File the MKRECV header is written to and
        that MKRECV is started with.
    :param mkrecv_cpu_set: ``taskset -c`` argument for MKRECV.
    :param apsuse_cpu_set: ``taskset -c`` argument for ``apsuse``.
    :param sensor_prefix: Prefix prepended to the name of every sensor.
    :param dada_key: Key of the DADA buffer shared by MKRECV and ``apsuse``.
    """

    def __init__(self, capture_interface, control_socket,
                 mkrecv_config_filename, mkrecv_cpu_set, apsuse_cpu_set,
                 sensor_prefix, dada_key):
        # Join the parameters first: the message has a single placeholder,
        # so passing them individually would only log the first one.
        log.info("Building ApsCapture instance with parameters: ({})".format(
            ", ".join(str(param) for param in (
                capture_interface, control_socket, mkrecv_config_filename,
                mkrecv_cpu_set, apsuse_cpu_set, sensor_prefix, dada_key))))
        self._capture_interface = capture_interface
        self._control_socket = control_socket
        self._mkrecv_config_filename = mkrecv_config_filename
        self._mkrecv_cpu_set = mkrecv_cpu_set
        self._apsuse_cpu_set = apsuse_cpu_set
        self._sensor_prefix = sensor_prefix
        self._dada_input_key = dada_key
        self._mkrecv_proc = None
        self._apsuse_proc = None
        self._dada_db_proc = None
        # Both monitors only exist while a capture is running; keeping them
        # defined lets capture_stop() be called at any time.
        self._capture_monitor = None
        self._ingress_buffer_monitor = None
        self._internal_beam_mapping = {}
        self._sensors = []
        self._capturing = False
        self.ioloop = IOLoop.current()
        self.setup_sensors()

    def add_sensor(self, sensor):
        """Prefix *sensor*'s name with the sensor prefix and register it."""
        sensor.name = "{}-{}".format(self._sensor_prefix, sensor.name)
        self._sensors.append(sensor)

    def setup_sensors(self):
        """Create and register all sensors exposed by this instance."""
        self._config_sensor = Sensor.string(
            "configuration",
            description="The current configuration of the capture instance",
            default="",
            initial_status=Sensor.UNKNOWN)
        self.add_sensor(self._config_sensor)

        self._mkrecv_header_sensor = Sensor.string(
            "mkrecv-capture-header",
            description=
            "The MKRECV/DADA header used for configuring capture with MKRECV",
            default="",
            initial_status=Sensor.UNKNOWN)
        self.add_sensor(self._mkrecv_header_sensor)

        self._apsuse_args_sensor = Sensor.string(
            "apsuse-arguments",
            description="The command line arguments used to invoke apsuse",
            default="",
            initial_status=Sensor.UNKNOWN)
        self.add_sensor(self._apsuse_args_sensor)

        self._mkrecv_heap_loss = Sensor.float(
            "fbf-heap-loss",
            description=("The percentage of FBFUSE heaps lost "
                         "(within MKRECV statistics window)"),
            default=0.0,
            initial_status=Sensor.UNKNOWN,
            unit="%")
        self.add_sensor(self._mkrecv_heap_loss)

        self._ingress_buffer_percentage = Sensor.float(
            "ingress-buffer-fill-level",
            description=("The percentage fill level for the capture "
                         "buffer between MKRECV and APSUSE"),
            default=0.0,
            initial_status=Sensor.UNKNOWN,
            unit="%")
        self.add_sensor(self._ingress_buffer_percentage)

    @coroutine
    def _start_db(self, key, block_size, nblocks, timeout=100.0):
        """Create the DADA buffer and wait until its memory shows up as cached.

        :param key: DADA buffer key.
        :param block_size: Size of one block in bytes.
        :param nblocks: Number of blocks.
        :param timeout: Seconds to wait for the memory to be cached.
        :raises Exception: if the cached memory has not reached
            ``block_size * nblocks`` within *timeout* seconds.
        """
        log.debug(("Building DADA buffer: key={}, block_size={}, "
                   "nblocks={}").format(key, block_size, nblocks))
        # A concrete list (not a lazy map object) so the arguments can be
        # consumed more than once regardless of the Python version.
        cmdline = [str(arg) for arg in (
            "dada_db", "-k", key, "-b", block_size, "-n", nblocks,
            "-l", "-p", "-w")]
        self._dada_db_proc = Popen(cmdline, stdout=PIPE, stderr=PIPE,
                                   shell=False, close_fds=True)
        required = block_size * nblocks
        start = time.time()
        while psutil.virtual_memory().cached < required:
            log.info("Cached: {} bytes, require {} bytes".format(
                psutil.virtual_memory().cached, required))
            yield sleep(1.0)
            if time.time() - start > timeout:
                raise Exception(
                    "Caching of DADA buffer took longer than {} seconds".
                    format(timeout))
        log.info("Took {} seconds to allocate {}x{} GB DADA buffer".format(
            time.time() - start, block_size / 1e9, nblocks))

    def _stop_db(self):
        """Terminate the ``dada_db`` process, if one was started."""
        log.debug("Destroying DADA buffer")
        if self._dada_db_proc is not None:
            self._dada_db_proc.terminate()
            self._dada_db_proc.wait()
            self._dada_db_proc = None

    @staticmethod
    def _make_beam_list(indices):
        """Compress *indices* into a comma separated list of ranges.

        Runs of consecutive values ``p .. q`` are written as ``p:q+1``
        (exclusive upper bound); isolated values are written as is.
        """
        parts = []
        # Consecutive values share the same (value - position) difference.
        for _, run in itertools.groupby(enumerate(indices),
                                        lambda pair: pair[1] - pair[0]):
            run = list(run)
            first, last = run[0][1], run[-1][1]
            if first == last:
                parts.append("{}".format(first))
            else:
                parts.append("{}:{}".format(first, last + 1))
        return ",".join(parts)

    @coroutine
    def capture_start(self, config):
        """Allocate the capture buffer and start ``apsuse`` and MKRECV.

        :param config: Dictionary; the keys read here are ``beam-ids``,
            ``nchans``, ``nchans-per-heap``, ``heap-size``, ``filesize``,
            ``base-output-dir``, ``bandwidth``, ``centre-frequency``,
            ``sampling-interval``, ``sync-epoch``, ``sample-clock``,
            ``mcast-groups``, ``mcast-port``, ``idx1-step`` and
            ``stream-indices``.
        :raises Exception: if fewer than three capture blocks fit into
            ``AVAILABLE_CAPTURE_MEMORY`` or buffer allocation times out.
        """
        log.info("Preparing apsuse capture instance")
        log.info("Config: {}".format(config))
        nbeams = len(config['beam-ids'])
        # Floor division keeps these counts integral on Python 3 as well;
        # they are passed on to dada_db and to the apsuse command line.
        npartitions = config['nchans'] // config['nchans-per-heap']
        nsamples = config['heap-size'] // config['nchans-per-heap']
        heap_group_size = config['heap-size'] * nbeams * npartitions

        # Start from the number of heap groups that fit in the maximum
        # block size, then round up to a power of two (not necessary, but
        # helpful) and use at least 8 groups.
        ngroups_data = int(MAX_DADA_BLOCK_SIZE / heap_group_size)
        ngroups_data = 2**((ngroups_data - 1).bit_length())
        ngroups_data = max(ngroups_data, 8)

        # Make DADA buffer and start watchers
        log.info("Creating capture buffer")
        capture_block_size = ngroups_data * heap_group_size
        if (capture_block_size * OPTIMAL_CAPTURE_BLOCKS
                > AVAILABLE_CAPTURE_MEMORY):
            capture_block_count = int(
                AVAILABLE_CAPTURE_MEMORY / capture_block_size)
            if capture_block_count < 3:
                raise Exception("Cannot allocate more than 2 capture blocks")
        else:
            capture_block_count = OPTIMAL_CAPTURE_BLOCKS
        log.debug("Creating dada buffer for input with key '{}'".format(
            self._dada_input_key))
        yield self._start_db(self._dada_input_key, capture_block_size,
                             capture_block_count)
        log.info("Capture buffer ready")
        self._config_sensor.set_value(json.dumps(config))

        # Internal beam index = position of the beam in config['beam-ids'].
        for idx, beam in enumerate(config['beam-ids']):
            self._internal_beam_mapping[beam] = idx

        # Start APSUSE processing code
        apsuse_cmdline = [
            "taskset", "-c", self._apsuse_cpu_set,
            "apsuse",
            "--input_key", self._dada_input_key,
            "--ngroups", ngroups_data,
            "--nbeams", nbeams,
            "--nchannels", config['nchans-per-heap'],
            "--nsamples", nsamples,
            "--nfreq", npartitions,
            "--size", int(config['filesize']),
            "--socket", self._control_socket,
            "--dir", config["base-output-dir"],
            "--log_level", "info"
        ]
        apsuse_cmdline_str = " ".join(map(str, apsuse_cmdline))
        log.info("Starting APSUSE")
        log.debug(apsuse_cmdline_str)
        self._apsuse_proc = ManagedProcess(apsuse_cmdline,
                                           stdout_handler=log.debug,
                                           stderr_handler=log.error)
        self._apsuse_args_sensor.set_value(apsuse_cmdline_str)
        yield sleep(5)

        # Start MKRECV capture code
        mkrecv_config = {
            'dada_mode': 4,
            'dada_key': self._dada_input_key,
            'bandwidth': config['bandwidth'],
            'centre_frequency': config['centre-frequency'],
            'nchannels': config["nchans"],
            'sampling_interval': config['sampling-interval'],
            'sync_epoch': config['sync-epoch'],
            'sample_clock': config['sample-clock'],
            'mcast_sources': ",".join(config['mcast-groups']),
            'nthreads': len(config['mcast-groups']) + 1,
            'mcast_port': str(config['mcast-port']),
            'interface': self._capture_interface,
            'timestamp_step': config['idx1-step'],
            'timestamp_modulus': 1,
            'beam_ids_csv': self._make_beam_list(config['stream-indices']),
            'freq_ids_csv': "0:{}:{}".format(config['nchans'],
                                             config['nchans-per-heap']),
            'heap_size': config['heap-size']
        }
        mkrecv_header = make_mkrecv_header(
            mkrecv_config, outfile=self._mkrecv_config_filename)
        log.info("Determined MKRECV configuration:\n{}".format(mkrecv_header))
        self._mkrecv_header_sensor.set_value(mkrecv_header)

        def update_heap_loss_sensor(curr, total, avg, window):
            # The sensor reports 100 minus the average handed over by the
            # MKRECV stdout handler.
            self._mkrecv_heap_loss.set_value(100.0 - avg)

        mkrecv_sensor_updater = MkrecvStdoutHandler(
            callback=update_heap_loss_sensor)

        def mkrecv_aggregated_output_handler(line):
            log.debug(line)
            mkrecv_sensor_updater(line)

        log.info("Starting MKRECV")
        self._mkrecv_proc = ManagedProcess(
            [
                "taskset", "-c", self._mkrecv_cpu_set,
                "mkrecv_rnt", "--header", self._mkrecv_config_filename,
                "--quiet"
            ],
            stdout_handler=mkrecv_aggregated_output_handler,
            stderr_handler=log.error)
        yield sleep(5)

        def exit_check_callback():
            # NOTE(review): the watcher is stopped only after a dead process
            # has been detected, so capture_stop is scheduled exactly once
            # and healthy pipelines keep being monitored -- confirm this is
            # the intended nesting of the final stop() call.
            if not self._mkrecv_proc.is_alive():
                log.error("mkrecv_nt exited unexpectedly")
            elif not self._apsuse_proc.is_alive():
                log.error("apsuse pipeline exited unexpectedly")
            else:
                return
            self.ioloop.add_callback(self.capture_stop)
            self._capture_monitor.stop()

        self._capture_monitor = PeriodicCallback(exit_check_callback, 1000)
        self._capture_monitor.start()

        self._ingress_buffer_monitor = DbMonitor(
            self._dada_input_key,
            callback=lambda params: self._ingress_buffer_percentage.set_value(
                params["fraction-full"]))
        self._ingress_buffer_monitor.start()
        self._capturing = True
        log.info("Successfully started capture pipeline")

    def _send_command(self, message, failure_message):
        """Send *message* as JSON over the control socket and check the reply.

        The socket is closed even when sending or receiving fails.  A reply
        that cannot be parsed is logged and otherwise ignored, matching the
        long-standing best-effort behaviour.

        :param message: JSON-serialisable command dictionary.
        :param failure_message: Text of the exception raised when the
            ``response`` field of the reply is not ``"success"``.
        :raises Exception: if the reply reports anything but success.
        """
        payload = json.dumps(message)
        log.debug("Connecting to apsuse instance via socket")
        client = UDSClient(self._control_socket)
        try:
            log.debug("Sending message: {}".format(payload))
            client.send(payload)
            response_str = client.recv(timeout=3)
            try:
                response = json.loads(response_str)["response"]
            except Exception:
                log.exception(
                    "Unable to parse JSON returned from apsuse application")
            else:
                log.debug("Response: {}".format(response_str))
                if response != "success":
                    raise Exception(failure_message)
        finally:
            client.close()
            log.debug("Closed socket connection")

    def target_start(self, beam_info, output_dir):
        """Send target information to ``apsuse`` and trigger file writing.

        The message sent has the form::

            {
                "command": "start",
                "directory": <output_dir>,
                "beam_parameters": [
                    {"idx": 0, "name": "cfbf00000", "source": "PSRJ1823+3410",
                     "ra": "00:00:00.00", "dec": "00:00:00.00"},
                    ...
                ]
            }

        ``idx`` is the internal index of the beam as recorded by
        :meth:`capture_start` (position in ``config['beam-ids']``), not its
        global identifier.  Beams that are not part of the running capture
        are skipped.

        :param beam_info: Mapping of beam identifier to target string.
        :param output_dir: Directory sent to ``apsuse`` for this target.
        :raises Exception: if ``apsuse`` does not report success.
        """
        log.info("Target start on capture instance")
        beam_params = []
        message_dict = {
            "command": "start",
            "directory": output_dir,
            "beam_parameters": beam_params
        }
        log.info("Parsing beam information")
        for beam, target_str in beam_info.items():
            if beam not in self._internal_beam_mapping:
                continue
            idx = self._internal_beam_mapping[beam]
            target = Target(target_str)
            ra, dec = map(str, target.radec())
            log.info(
                "IDX: {}, name: {}, ra: {}, dec: {}, source: {}".format(
                    idx, beam, ra, dec, target.name))
            beam_params.append({
                "idx": idx,
                "name": beam,
                "source": target.name,
                "ra": ra,
                "dec": dec
            })
        self._send_command(message_dict, "Failed to start APSUSE recording")

    def target_stop(self):
        """Tell ``apsuse`` to stop file writing (``{"command": "stop"}``).

        :raises Exception: if ``apsuse`` does not report success.
        """
        log.info("Target stop request on capture instance")
        self._send_command({"command": "stop"},
                           "Failed to stop APSUSE recording")

    @coroutine
    def capture_stop(self):
        """Stop monitors and processes and destroy the DADA buffer.

        Safe to call when no capture is running or when it has already been
        called: every resource is released at most once.
        """
        log.info("Capture stop request on capture instance")
        self._capturing = False
        self._internal_beam_mapping = {}
        log.info("Stopping capture monitors")
        if self._capture_monitor is not None:
            self._capture_monitor.stop()
            self._capture_monitor = None
        if self._ingress_buffer_monitor is not None:
            self._ingress_buffer_monitor.stop()
            self._ingress_buffer_monitor = None
        log.info("Stopping MKRECV instance")
        if self._mkrecv_proc is not None:
            self._mkrecv_proc.terminate()
            self._mkrecv_proc = None
        log.info("Stopping PSRDADA_CPP instance")
        if self._apsuse_proc is not None:
            self._apsuse_proc.terminate()
            self._apsuse_proc = None
        log.info("Destroying DADA buffers")
        self._stop_db()
class UpdateManager:
    """Track and run updates for Moonraker, Klipper, system packages and
    configured client repos.

    The manager registers the ``/machine/update/*`` endpoints, keeps one
    updater object per application in ``self.updaters`` and tracks the
    GitHub API rate limit reported by GitHub.

    :param config: Configuration helper for this component.  It must
        provide ``get_server``, ``get``, ``getboolean``,
        ``read_supplemental_config``, ``get_prefix_sections``, ``error`` and
        section lookup via ``config[section]``.
    """

    def __init__(self, config):
        self.server = config.get_server()
        self.config = config
        self.config.read_supplemental_config(SUPPLEMENTAL_CFG_PATH)
        self.repo_debug = config.getboolean('enable_repo_debug', False)
        auto_refresh_enabled = config.getboolean('enable_auto_refresh', False)
        self.distro = config.get('distro', "debian").lower()
        if self.distro not in SUPPORTED_DISTROS:
            raise config.error(f"Unsupported distro: {self.distro}")
        if self.repo_debug:
            logging.warning("UPDATE MANAGER: REPO DEBUG ENABLED")
        env = sys.executable
        mooncfg = self.config[f"update_manager static {self.distro} moonraker"]
        self.updaters = {
            "system": PackageUpdater(self),
            "moonraker": GitUpdater(self, mooncfg, MOONRAKER_PATH, env)
        }
        # (application name, request id) while an update is running.
        self.current_update = None

        # TODO: Check for client config in [update_manager].  This is
        # deprecated and will be removed.
        client_repo = config.get("client_repo", None)
        if client_repo is not None:
            client_path = config.get("client_path")
            name = client_repo.split("/")[-1]
            self.updaters[name] = WebUpdater(self, {
                'repo': client_repo,
                'path': client_path
            })

        client_sections = self.config.get_prefix_sections(
            "update_manager client")
        for section in client_sections:
            cfg = self.config[section]
            name = section.split()[-1]
            if name in self.updaters:
                raise config.error("Client repo named %s already added"
                                   % (name, ))
            client_type = cfg.get("type")
            if client_type == "git_repo":
                self.updaters[name] = GitUpdater(self, cfg)
            elif client_type == "web":
                self.updaters[name] = WebUpdater(self, cfg)
            else:
                raise config.error("Invalid type '%s' for section [%s]"
                                   % (client_type, section))

        # GitHub API Rate Limit Tracking
        self.gh_rate_limit = None
        self.gh_limit_remaining = None
        self.gh_limit_reset_time = None
        self.gh_init_evt = Event()
        self.cmd_request_lock = Lock()
        self.is_refreshing = False

        # Auto Status Refresh
        self.last_auto_update_time = 0
        self.refresh_cb = None
        if auto_refresh_enabled:
            self.refresh_cb = PeriodicCallback(self._handle_auto_refresh,
                                               UPDATE_REFRESH_INTERVAL_MS)
            self.refresh_cb.start()

        AsyncHTTPClient.configure(None, defaults=dict(user_agent="Moonraker"))
        self.http_client = AsyncHTTPClient()

        for app in ("moonraker", "klipper", "system", "client"):
            self.server.register_endpoint(f"/machine/update/{app}",
                                          ["POST"],
                                          self._handle_update_request)
        self.server.register_endpoint("/machine/update/status", ["GET"],
                                      self._handle_status_request)

        # Register Ready Event
        self.server.register_event_handler("server:klippy_identified",
                                           self._set_klipper_repo)
        # Initialize GitHub API Rate Limits and configured updaters
        IOLoop.current().spawn_callback(self._initalize_updaters,
                                        list(self.updaters.values()))

    async def _initalize_updaters(self, initial_updaters):
        """Initialise the rate limit, then refresh every initial updater.

        ``is_refreshing`` is held True for the duration and is always
        cleared, even if a refresh raises.
        """
        self.is_refreshing = True
        try:
            await self._init_api_rate_limit()
            for updater in initial_updaters:
                if isinstance(updater, PackageUpdater):
                    ret = updater.refresh(False)
                else:
                    ret = updater.refresh()
                if asyncio.iscoroutine(ret):
                    await ret
        finally:
            self.is_refreshing = False

    async def _set_klipper_repo(self):
        """Create (or replace) the Klipper updater from the klippy info."""
        kinfo = self.server.get_klippy_info()
        if not kinfo:
            logging.info("No valid klippy info received")
            return
        kpath = kinfo['klipper_path']
        env = kinfo['python_path']
        kupdater = self.updaters.get('klipper', None)
        if kupdater is not None and kupdater.repo_path == kpath and \
                kupdater.env == env:
            # Current Klipper Updater is valid
            return
        kcfg = self.config[f"update_manager static {self.distro} klipper"]
        self.updaters['klipper'] = GitUpdater(self, kcfg, kpath, env)
        await self.updaters['klipper'].refresh()

    async def _check_klippy_printing(self):
        """Return True when ``print_stats`` reports the state "printing"."""
        klippy_apis = self.server.lookup_plugin('klippy_apis')
        result = await klippy_apis.query_objects({'print_stats': None},
                                                 default={})
        pstate = result.get('print_stats', {}).get('state', "")
        return pstate.lower() == "printing"

    async def _handle_auto_refresh(self):
        """Periodic callback: refresh all updaters and broadcast the status.

        Nothing happens while Klippy is printing, when the last automatic
        refresh is less than ``MIN_REFRESH_TIME`` seconds ago, or when the
        local hour is ``MAX_PKG_UPDATE_HOUR`` or later.
        """
        if await self._check_klippy_printing():
            # Don't Refresh during a print
            logging.info("Klippy is printing, auto refresh aborted")
            return
        cur_time = time.time()
        cur_hour = time.localtime(cur_time).tm_hour
        time_diff = cur_time - self.last_auto_update_time
        if time_diff < MIN_REFRESH_TIME or cur_hour >= MAX_PKG_UPDATE_HOUR:
            # Not within the update time window
            return
        self.last_auto_update_time = cur_time
        vinfo = {}
        # If a refresh is already being processed only collect the status.
        need_refresh_all = not self.is_refreshing
        async with self.cmd_request_lock:
            self.is_refreshing = True
            try:
                for name, updater in list(self.updaters.items()):
                    if need_refresh_all:
                        ret = updater.refresh()
                        if asyncio.iscoroutine(ret):
                            await ret
                    if hasattr(updater, "get_update_status"):
                        vinfo[name] = updater.get_update_status()
            except Exception:
                logging.exception("Unable to Refresh Status")
                return
            finally:
                self.is_refreshing = False
        uinfo = {
            'version_info': vinfo,
            'github_rate_limit': self.gh_rate_limit,
            'github_requests_remaining': self.gh_limit_remaining,
            'github_limit_reset_time': self.gh_limit_reset_time,
            'busy': self.current_update is not None
        }
        self.server.send_event("update_manager:update_refreshed", uinfo)

    async def _handle_update_request(self, web_request):
        """Endpoint handler: update the application named by the endpoint.

        :returns: ``"ok"`` on success, or a message when the same
            application is already being updated.
        :raises: ``self.server.error`` while Klippy is printing or when no
            updater exists for the application; errors raised by the
            updater are reported to clients and re-raised.
        """
        if await self._check_klippy_printing():
            raise self.server.error("Update Refused: Klippy is printing")
        app = web_request.get_endpoint().split("/")[-1]
        if app == "client":
            app = web_request.get('name')
        inc_deps = web_request.get_boolean('include_deps', False)
        if self.current_update is not None and \
                self.current_update[0] == app:
            return f"Object {app} is currently being updated"
        updater = self.updaters.get(app, None)
        if updater is None:
            raise self.server.error(f"Updater {app} not available")
        async with self.cmd_request_lock:
            self.current_update = (app, id(web_request))
            try:
                await updater.update(inc_deps)
            except Exception as e:
                self.notify_update_response(f"Error updating {app}")
                self.notify_update_response(str(e), is_complete=True)
                raise
            finally:
                self.current_update = None
        return "ok"

    async def _handle_status_request(self, web_request):
        """Endpoint handler: return version and rate limit information.

        With ``refresh`` set the updaters are refreshed first, unless an
        update or a print is in progress or another refresh is already
        being processed.
        """
        check_refresh = web_request.get_boolean('refresh', False)
        # Don't refresh if a print is currently in progress or
        # if an update is in progress.  Just return the current
        # state
        if self.current_update is not None or \
                await self._check_klippy_printing():
            check_refresh = False
        need_refresh = False
        if check_refresh:
            # If there is an outstanding request processing a
            # refresh, we don't need to do it again.
            need_refresh = not self.is_refreshing
            await self.cmd_request_lock.acquire()
            self.is_refreshing = True
        vinfo = {}
        try:
            for name, updater in list(self.updaters.items()):
                await updater.check_initialized(120.)
                if need_refresh:
                    ret = updater.refresh()
                    if asyncio.iscoroutine(ret):
                        await ret
                if hasattr(updater, "get_update_status"):
                    vinfo[name] = updater.get_update_status()
        finally:
            # The lock is only held when a refresh was requested.
            if check_refresh:
                self.is_refreshing = False
                self.cmd_request_lock.release()
        return {
            'version_info': vinfo,
            'github_rate_limit': self.gh_rate_limit,
            'github_requests_remaining': self.gh_limit_remaining,
            'github_limit_reset_time': self.gh_limit_reset_time,
            'busy': self.current_update is not None
        }

    async def execute_cmd(self, cmd, timeout=10., notify=False, retries=1):
        """Run *cmd* through the shell_command plugin, retrying on failure.

        :raises: ``self.server.error`` once all *retries* are used up.
        """
        shell_command = self.server.lookup_plugin('shell_command')
        cb = self.notify_update_response if notify else None
        scmd = shell_command.build_shell_command(cmd, callback=cb)
        while retries:
            if await scmd.run(timeout=timeout, verbose=notify):
                break
            retries -= 1
        if not retries:
            raise self.server.error("Shell Command Error")

    async def execute_cmd_with_response(self, cmd, timeout=10.):
        """Run *cmd* and return its response.

        :raises: ``self.server.error`` when the command yields ``None``.
        """
        shell_command = self.server.lookup_plugin('shell_command')
        scmd = shell_command.build_shell_command(cmd, None)
        result = await scmd.run_with_response(timeout, retries=5)
        if result is None:
            raise self.server.error(f"Error Running Command: {cmd}")
        return result

    async def _init_api_rate_limit(self):
        """Query the GitHub rate limit until it succeeds, then set the
        initialisation event.  Failed attempts are retried every 30 s."""
        url = "https://api.github.com/rate_limit"
        while True:
            try:
                resp = await self.github_api_request(url, is_init=True)
                core = resp['resources']['core']
                self.gh_rate_limit = core['limit']
                self.gh_limit_remaining = core['remaining']
                self.gh_limit_reset_time = core['reset']
            except Exception:
                logging.exception("Error Initializing GitHub API Rate Limit")
                await tornado.gen.sleep(30.)
            else:
                reset_time = time.ctime(self.gh_limit_reset_time)
                logging.info(
                    "GitHub API Rate Limit Initialized\n"
                    f"Rate Limit: {self.gh_rate_limit}\n"
                    f"Rate Limit Remaining: {self.gh_limit_remaining}\n"
                    f"Rate Limit Reset Time: {reset_time}, "
                    f"Seconds Since Epoch: {self.gh_limit_reset_time}")
                break
        self.gh_init_evt.set()

    async def github_api_request(self, url, etag=None, is_init=False):
        """Fetch *url* from the GitHub API and return the decoded JSON.

        :param url: Request URL.
        :param etag: Optional ETag sent as ``If-None-Match``.
        :param is_init: True for the initial rate limit query; skips the
            wait for initialisation and the rate limit bookkeeping.
        :returns: The decoded response with an added ``etag`` key, or
            ``None`` when GitHub answers 304 (not modified).
        :raises: ``self.server.error`` on initialisation timeout, exhausted
            rate limit, a 403 response, or when all 5 attempts fail.
        """
        if not is_init:
            timeout = time.time() + 30.
            try:
                await self.gh_init_evt.wait(timeout)
            except Exception:
                raise self.server.error(
                    "Timeout while waiting for GitHub "
                    "API Rate Limit initialization")
        if self.gh_limit_remaining == 0:
            curtime = time.time()
            if curtime < self.gh_limit_reset_time:
                # Report when the limit resets, not the remaining count.
                reset_time = time.ctime(self.gh_limit_reset_time)
                raise self.server.error(
                    f"GitHub Rate Limit Reached\nRequest: {url}\n"
                    f"Limit Reset Time: {reset_time}")
        headers = {"Accept": "application/vnd.github.v3+json"}
        if etag is not None:
            headers['If-None-Match'] = etag
        retries = 5
        while retries:
            try:
                timeout = time.time() + 10.
                fut = self.http_client.fetch(
                    url, headers=headers, connect_timeout=5.,
                    request_timeout=5., raise_error=False)
                resp = await tornado.gen.with_timeout(timeout, fut)
            except Exception:
                retries -= 1
                msg = f"Error Processing GitHub API request: {url}"
                if not retries:
                    raise self.server.error(msg)
                logging.exception(msg)
                await tornado.gen.sleep(1.)
                continue
            etag = resp.headers.get('etag', None)
            if etag is not None and etag[:2] == "W/":
                # Drop the weak validator prefix.
                etag = etag[2:]
            logging.info(
                "GitHub API Request Processed\n"
                f"URL: {url}\n"
                f"Response Code: {resp.code}\n"
                f"Response Reason: {resp.reason}\n"
                f"ETag: {etag}")
            if resp.code == 403:
                raise self.server.error(
                    f"Forbidden GitHub Request: {resp.reason}")
            elif resp.code == 304:
                logging.info(f"Github Request not Modified: {url}")
                return None
            if resp.code != 200:
                retries -= 1
                if not retries:
                    raise self.server.error(
                        f"Github Request failed: {resp.code} {resp.reason}")
                logging.info(
                    f"Github request error, {retries} retries remaining")
                await tornado.gen.sleep(1.)
                continue
            # Update rate limit on return success
            if 'X-Ratelimit-Limit' in resp.headers and not is_init:
                self.gh_rate_limit = int(resp.headers['X-Ratelimit-Limit'])
                self.gh_limit_remaining = int(
                    resp.headers['X-Ratelimit-Remaining'])
                self.gh_limit_reset_time = float(
                    resp.headers['X-Ratelimit-Reset'])
            decoded = json.loads(resp.body)
            decoded['etag'] = etag
            return decoded

    async def http_download_request(self, url):
        """Download *url* and return the response body.

        Up to 5 attempts are made; the exception of the last failed attempt
        is re-raised.
        """
        retries = 5
        while retries:
            try:
                timeout = time.time() + 130.
                fut = self.http_client.fetch(
                    url, headers={"Accept": "application/zip"},
                    connect_timeout=5., request_timeout=120.)
                resp = await tornado.gen.with_timeout(timeout, fut)
            except Exception:
                retries -= 1
                logging.exception("Error Processing Download")
                if not retries:
                    raise
                await tornado.gen.sleep(1.)
                continue
            return resp.body

    def notify_update_response(self, resp, is_complete=False):
        """Send *resp* to clients as an ``update_manager:update_response``.

        :param resp: Message as ``str`` or ``bytes``; surrounding
            whitespace is removed.
        :param is_complete: Value of the ``complete`` field.
        """
        resp = resp.strip()
        if isinstance(resp, bytes):
            resp = resp.decode()
        notification = {
            'message': resp,
            'application': None,
            'proc_id': None,
            'complete': is_complete
        }
        if self.current_update is not None:
            notification['application'] = self.current_update[0]
            notification['proc_id'] = self.current_update[1]
        self.server.send_event("update_manager:update_response",
                               notification)

    def close(self):
        """Close the HTTP client and stop the auto refresh callback."""
        self.http_client.close()
        if self.refresh_cb is not None:
            self.refresh_cb.stop()
class BlockingPool(object):
    """A connection pool that manages blocking PostgreSQL connections and
    cursors.

    :param min_conn: The minimum amount of connections that is created when
        a connection pool is created.
    :param max_conn: The maximum amount of connections the connection pool
        can have. Asking for a new connection when the pool already holds
        ``max_conn`` connections raises a ``PoolError`` exception.
    :param cleanup_timeout: Time in seconds between pool cleanups.
        Connections will be closed until there are ``min_conn`` left. A
        value of zero or less disables periodic cleanup.
    :param database: The database name
    :param user: User name used to authenticate
    :param password: Password used to authenticate
    :param connection_factory: Using the connection_factory parameter a
        different class or connections factory can be specified. It should
        be a callable object taking a dsn argument.

    All remaining positional and keyword arguments are passed to
    ``psycopg2.connect`` unchanged.
    """

    def __init__(self, min_conn=1, max_conn=20, cleanup_timeout=10,
                 *args, **kwargs):
        self.min_conn = min_conn
        self.max_conn = max_conn
        self.closed = False

        self._args = args
        self._kwargs = kwargs

        # Defined up front so close() also works when cleanup is disabled.
        self._cleaner = None

        self._pool = []

        for i in range(self.min_conn):
            self._new_conn()

        # Create a periodic callback that tries to close inactive connections
        if cleanup_timeout > 0:
            self._cleaner = PeriodicCallback(self._clean_pool,
                                             cleanup_timeout * 1000)
            self._cleaner.start()

    def _new_conn(self):
        """Create a new connection, add it to the pool and return it.

        :raises PoolError: if the pool already holds ``max_conn``
            connections.
        """
        # ``>=`` keeps the pool at max_conn; ``>`` let it grow to one more.
        if len(self._pool) >= self.max_conn:
            raise PoolError('connection pool exhausted')
        conn = psycopg2.connect(*self._args, **self._kwargs)
        self._pool.append(conn)
        return conn

    def _get_free_conn(self):
        """Look for a free connection and return it.

        `None` is returned when no free connection can be found.

        :raises PoolError: if the pool is closed.
        """
        if self.closed:
            raise PoolError('connection pool is closed')
        for conn in self._pool:
            if conn.status == STATUS_READY:
                return conn
        return None

    def get_connection(self):
        """Get a connection from the pool.

        If there's no free connection available, a new connection will be
        created.
        """
        connection = self._get_free_conn()
        if not connection:
            connection = self._new_conn()
        return connection

    def _clean_pool(self):
        """Close a number of inactive connections when the number of
        connections in the pool exceeds the number in `min_conn`.

        :raises PoolError: if the pool is closed.
        """
        if self.closed:
            raise PoolError('connection pool is closed')
        surplus = len(self._pool) - self.min_conn
        if surplus <= 0:
            return
        # Iterate over a copy because connections are removed on the way.
        for conn in self._pool[:]:
            if conn.status == STATUS_READY:
                conn.close()
                surplus -= 1
                self._pool.remove(conn)
            if not surplus:
                break

    def close(self):
        """Close all open connections in the pool.

        :raises PoolError: if the pool is already closed.
        """
        if self.closed:
            raise PoolError('connection pool is closed')
        for conn in self._pool:
            if not conn.closed:
                conn.close()
        # No cleaner exists when the pool was created with cleanup disabled.
        cleaner = getattr(self, '_cleaner', None)
        if cleaner is not None:
            cleaner.stop()
        self._pool = []
        self.closed = True
class AsyncPool(object):
    """A connection pool that manages asynchronous PostgreSQL connections
    and cursors.

    :param min_conn:
        The number of connections requested when the pool is created. Pool
        cleanups never shrink the pool below this number.
    :param max_conn:
        The maximum number of connections the pool can hold. Asking for a
        new connection while the pool already holds ``max_conn``
        connections (after a cleanup attempt) raises ``PoolError``.
    :param cleanup_timeout:
        Time in seconds between pool cleanups. Idle connections are closed
        until ``min_conn`` are left. A value ``<= 0`` disables the periodic
        cleanup.
    :param ioloop: An instance of Tornado's IOLoop.
    :param host: The database host address (defaults to UNIX socket if not
        provided)
    :param port: The database host port (defaults to 5432 if not provided)
    :param database: The database name
    :param user: User name used to authenticate
    :param password: Password used to authenticate
    :param connection_factory:
        Using the connection_factory parameter a different class or
        connections factory can be specified. It should be a callable
        object taking a dsn argument.

    All remaining positional and keyword arguments are passed unchanged to
    ``AsyncConnection.open``.
    """

    def __init__(self, min_conn=1, max_conn=20, cleanup_timeout=10,
                 ioloop=None, *args, **kwargs):
        self.min_conn = min_conn
        self.max_conn = max_conn
        self.closed = False
        self._ioloop = ioloop or IOLoop.instance()
        self._args = args
        self._kwargs = kwargs
        self._last_reconnect = 0

        # Always defined, so close() also works when cleanup is disabled.
        self._cleaner = None

        self._pool = []
        for _ in range(self.min_conn):
            self._new_conn()
        self._last_reconnect = time.time()

        # Create a periodic callback that tries to close inactive connections
        if cleanup_timeout > 0:
            self._cleaner = PeriodicCallback(self._clean_pool,
                                             cleanup_timeout * 1000)
            self._cleaner.start()

    def _new_conn(self, callback=None, callback_args=None):
        """Create a new connection.

        The connection is appended to the pool by a callback handed to
        ``AsyncConnection.open``, not by this method directly.

        :param callback: Optional callable invoked after the pool callback.
        :param callback_args: Parameters for the callback - connection will
            be appended to the parameters
        :raises PoolError: when the pool is still full after a cleanup.
        """
        args = list(callback_args) if callback_args else []

        # ``>=`` keeps the pool size within max_conn; a plain ``>`` would
        # let the pool grow to max_conn + 1 connections.
        if len(self._pool) >= self.max_conn:
            self._clean_pool()
        if len(self._pool) >= self.max_conn:
            raise PoolError('connection pool exhausted')

        # 1/4 second delay between reconnection
        timeout = self._last_reconnect + .25
        timenow = time.time()
        if timenow > timeout or len(self._pool) <= self.min_conn:
            self._last_reconnect = timenow
            conn = AsyncConnection(self._ioloop)
            # add new connection to the pool
            callbacks = [partial(self._pool.append, conn)]
            if callback:
                callbacks.append(partial(callback, *(args + [conn])))
            conn.open(callbacks, *self._args, **self._kwargs)
        else:
            # recursive timeout call, retaining the parameters
            self._ioloop.add_timeout(
                timeout, partial(self._new_conn, callback, args))

    def _get_free_conn(self):
        """Look for a free connection and return it.

        `None` is returned when no free connection can be found.

        :raises PoolError: when the pool is closed.
        """
        if self.closed:
            raise PoolError('connection pool is closed')
        for conn in self._pool:
            if not conn.isexecuting():
                return conn
        return None

    def get_connection(self, callback=None, callback_args=None):
        """Get a connection, trying available ones, and if not available -
        create a new one; Afterwards, the callback will be called

        :param callback: Optional callable that receives ``callback_args``
            followed by the connection. When omitted nothing is called.
        :param callback_args: Sequence of leading arguments for the callback.
        """
        connection = self._get_free_conn()
        if connection is None:
            self._new_conn(callback, callback_args)
        elif callback is not None:
            args = list(callback_args) if callback_args else []
            callback(*(args + [connection]))

    def new_cursor(self, function, function_args=(), callback=None,
                   cursor_kwargs=None, connection=None, transaction=False):
        """Create a new cursor.

        If there's no connection available, a new connection will be created
        and `new_cursor` will be called again after the connection has been
        made.

        :param function: ``execute``, ``executemany`` or ``callproc``.
        :param function_args: A tuple with the arguments for the specified
            function.
        :param callback: A callable that is executed once the operation is
            done.
        :param cursor_kwargs: A dictionary with Psycopg's
            `connection.cursor`_ arguments.
        :param connection: An ``AsyncConnection`` connection. Optional.
        :param transaction: When true, no other connection is looked up;
            ``TransactionError`` is raised if ``connection`` is missing or
            fails.
        :raises TransactionError: see ``transaction``.

        .. _connection.cursor: http://initd.org/psycopg/docs/connection.html#connection.cursor
        """
        if cursor_kwargs is None:
            cursor_kwargs = {}

        if connection is not None:
            try:
                connection.cursor(function, function_args, callback,
                                  cursor_kwargs)
                return
            except (DatabaseError, InterfaceError):
                # Recover from lost connection
                logging.warning('Requested connection was closed')
                # The connection may already have been dropped from the
                # pool (for example by a cleanup).
                if connection in self._pool:
                    self._pool.remove(connection)

        # if no connection, or if exception caught
        if not transaction:
            self.get_connection(callback=self.new_cursor, callback_args=[
                function, function_args, callback, cursor_kwargs])
        else:
            raise TransactionError

    def _clean_pool(self):
        """Close a number of inactive connections when the number of
        connections in the pool exceeds the number in `min_conn`.

        :raises PoolError: when the pool is closed.
        """
        if self.closed:
            raise PoolError('connection pool is closed')
        surplus = len(self._pool) - self.min_conn
        if surplus <= 0:
            return
        # Iterate over a copy because connections are removed on the way.
        for conn in self._pool[:]:
            if not conn.isexecuting():
                conn.close()
                self._pool.remove(conn)
                surplus -= 1
                if not surplus:
                    break

    def close(self):
        """Close all open connections in the pool.

        :raises PoolError: when the pool is already closed.
        """
        if self.closed:
            raise PoolError('connection pool is closed')
        for conn in self._pool:
            if not conn.closed:
                conn.close()
        # The cleaner only exists when periodic cleanup was enabled.
        if self._cleaner is not None:
            self._cleaner.stop()
        self._pool = []
        self.closed = True
class Events(threading.Thread):
    """Daemon thread that captures events and feeds them into a state
    object on the IOLoop thread.

    :param capp: Application object providing ``connection()`` and
        ``control.enable_events`` (presumably a Celery app - confirm
        against the caller).
    :param db: Path handed to ``shelve.open`` for loading/saving the state.
    :param persistent: When true, the state is loaded from ``db`` on
        construction and saved to it by :meth:`stop`.
    :param enable_events: When true, :meth:`start` starts the timer that
        periodically calls ``capp.control.enable_events``.
    :param io_loop: Tornado IOLoop to use; defaults to ``IOLoop.instance()``.
    :param state_save_interval: Interval passed to ``PeriodicCallback`` for
        saving the state; only used when ``persistent`` is true and the
        value is non-zero.
    :param kwargs: Passed to ``EventsState`` when no stored state is loaded.
    """

    # Interval handed to PeriodicCallback for re-enabling events
    # (Tornado documents this value as milliseconds).
    events_enable_interval = 5000

    def __init__(self, capp, db=None, persistent=False,
                 enable_events=True, io_loop=None, state_save_interval=0,
                 **kwargs):
        threading.Thread.__init__(self)
        self.daemon = True

        self.io_loop = io_loop or IOLoop.instance()
        self.capp = capp

        self.db = db
        self.persistent = persistent
        self.enable_events = enable_events
        self.state = None
        self.state_save_timer = None

        if self.persistent:
            logger.debug("Loading state from '%s'...", self.db)
            state = shelve.open(self.db)
            # Close the shelf even when reading the stored state fails.
            try:
                if state:
                    self.state = state['events']
            finally:
                state.close()

            if state_save_interval:
                self.state_save_timer = PeriodicCallback(
                    self.save_state, state_save_interval)

        # Fall back to a fresh state when nothing (or an empty state) was
        # loaded.
        if not self.state:
            self.state = EventsState(**kwargs)

        self.timer = PeriodicCallback(self.on_enable_events,
                                      self.events_enable_interval)

    def start(self):
        """Start the capture thread and the configured periodic timers."""
        threading.Thread.start(self)
        if self.enable_events:
            logger.debug("Starting enable events timer...")
            self.timer.start()

        if self.state_save_timer:
            logger.debug("Starting state save timer...")
            self.state_save_timer.start()

    def stop(self):
        """Stop the periodic timers and, when persistent, save the state.

        The capture thread itself is not stopped here.
        """
        if self.enable_events:
            logger.debug("Stopping enable events timer...")
            self.timer.stop()

        if self.state_save_timer:
            logger.debug("Stopping state save timer...")
            self.state_save_timer.stop()

        if self.persistent:
            self.save_state()

    def run(self):
        """Capture events forever, reconnecting after failures.

        The retry delay doubles on every attempt and is reset to 1 once a
        receiver has been created on a fresh connection.
        """
        try_interval = 1
        while True:
            try:
                try_interval *= 2

                with self.capp.connection() as conn:
                    recv = EventReceiver(conn,
                                         handlers={"*": self.on_event},
                                         app=self.capp)
                    try_interval = 1
                    logger.debug("Capturing events...")
                    recv.capture(limit=None, timeout=None, wakeup=True)
            except (KeyboardInterrupt, SystemExit):
                # Forward the interruption to the main thread.
                try:
                    import _thread as thread
                except ImportError:
                    import thread
                thread.interrupt_main()
            except Exception as e:
                logger.error(
                    "Failed to capture events: '%s', "
                    "trying again in %s seconds.",
                    e, try_interval)
                logger.debug(e, exc_info=True)
                time.sleep(try_interval)

    def save_state(self):
        """Write the current state to the shelve file at ``self.db``."""
        logger.debug("Saving state to '%s'...", self.db)
        state = shelve.open(self.db)
        # Close the shelf even when storing the state fails.
        try:
            state['events'] = self.state
        finally:
            state.close()

    def on_enable_events(self):
        """Ask the IOLoop to run ``capp.control.enable_events`` in its
        default executor."""
        # Periodically enable events for workers
        # launched after flower
        self.io_loop.run_in_executor(None, self.capp.control.enable_events)

    def on_event(self, event):
        """Schedule ``self.state.event(event)`` on the IOLoop."""
        # Call EventsState.event in ioloop thread to avoid synchronization
        self.io_loop.add_callback(partial(self.state.event, event))