Example #1
def start_game():
    ''' Main entry point for the application '''
    cache_actions()
    sockets = netutil.bind_sockets(8888)
    #if process.task_id() == None:
    #    tornado.process.fork_processes(-1, max_restarts = 10)
    server = HTTPServer(application)
    server.add_sockets(sockets)
    io_loop = IOLoop.instance()
    session_manager = SessionManager.Instance()
    if process.task_id() is None:
        scoring = PeriodicCallback(scoring_round, application.settings['ticks'], io_loop = io_loop)
        session_clean_up = PeriodicCallback(session_manager.clean_up, application.settings['clean_up_timeout'], io_loop = io_loop)
        scoring.start()
        session_clean_up.start()
    try:
        for count in range(3, 0, -1):
            logging.info("The game will begin in ... %d" % (count,))
            sleep(1)
        logging.info("Good hunting!")
        io_loop.start()
    except KeyboardInterrupt:
        if process.task_id() == 0:
            print '\r[!] Shutdown Everything!'
            session_clean_up.stop()
            io_loop.stop()
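
Several of these snippets, this one included, pass an explicit io_loop to PeriodicCallback (and HTTPServer). That keyword was deprecated in Tornado 5 and removed in Tornado 6, where callbacks are simply scheduled on the current IOLoop. A minimal modern sketch of the same periodic-timer idea (the callback body and the 1000 ms interval are placeholders, not values from the example above):

from tornado.ioloop import IOLoop, PeriodicCallback

def scoring_round():
    pass  # placeholder for the real scoring logic

# No io_loop argument in Tornado 5+/6; the timer runs on the current IOLoop.
scoring = PeriodicCallback(scoring_round, 1000)  # callback_time is in milliseconds
scoring.start()
IOLoop.current().start()
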
Example #2
class WSHandler(tornado.websocket.WebSocketHandler):
	def check_origin(self, origin):
		return True
    
	def open(self):
		with q_live.mutex:
			q_live.queue.clear()
		self.callback = PeriodicCallback(self.send_werte, 1)
		self.callback.start()
		print ('Connection open')	
	def send_werte(self):
		if not q_live.empty():
			signals, values = q_live.get()
			senden = dict(zip(signals,values))
			print(senden)
			json_send = json.dumps(senden)
			self.write_message(json_send)
			print(q_live.qsize())
			if q_live.qsize() >15:
				with q_live.mutex:
					q_live.queue.clear()
	def on_message(self, empf):
		print('Data received')

	def on_close(self):
		print('Connection closed!')
		self.callback.stop()
Example #3
class SendWebSocket(tornado.websocket.WebSocketHandler):
    #on_message -> receive data
    #write_message -> send data
    def __init__(self, *args, **keys):
        self.i = 0
        super(SendWebSocket, self).__init__(*args, **keys)

    def open(self):
        self.callback = PeriodicCallback(self._send_message, 1)
        self.callback.start()
        print "WebSocket opend"

    def on_message(self, message):
        print message


    def _send_message(self):
        self.i += 1
        self.write_message(str(self.i))
        if self.i % 20 == 0:
            self.write_message("\n")

    def on_close(self):
        self.callback.stop()
        print "WebSocket closed"
Example #4
class SocketHandler(WebSocketHandler):

    def check_origin(self, origin):
        """
        Overrides the parent method to return True for any request, since we are
        working without names

        :returns: bool True
        """
        return True

    def open(self):
        logging.info("Connection open from " + self.request.remote_ip)
        if self not in statusmonitor_open_sockets:
            statusmonitor_open_sockets.append(self) #http://stackoverflow.com/a/19571205
        self.callback = PeriodicCallback(self.send_data, 1000)

        self.callback.start()
        start_callback()

    def send_data(self):
        self.write_message(data_json)
        return

        
    def on_close(self):
        self.callback.stop()
        if self in statusmonitor_open_sockets:
            statusmonitor_open_sockets.remove(self)

        stop_callback()

    def send_update(self):
        pass
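
Returning True from check_origin disables Tornado's same-origin protection entirely. When the set of allowed front-end hosts is known, a stricter variant is a small change; a sketch subclassing the handler above (the host list is an illustrative assumption):

from urllib.parse import urlparse

class RestrictedSocketHandler(SocketHandler):
    # Hosts allowed to open a WebSocket to this handler (illustrative values)
    ALLOWED_HOSTS = {"localhost", "status.example.com"}

    def check_origin(self, origin):
        # `origin` is the value of the Origin request header, e.g. "https://status.example.com"
        return urlparse(origin).hostname in self.ALLOWED_HOSTS
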
Example #5
class WebSocketconnectionsHandler(tornado.websocket.WebSocketHandler):

    def __init__(self, *args, **kwargs):
        logger.debug("Creating WebSocket connections handler")
        super(WebSocketconnectionsHandler, self).__init__(*args, **kwargs)
        # No WebSocket connection yet
        self.connected = False
        # We have not counted the connections yet
        self.connections = 0
        # Update the connection count
        self.update()
        # Setup periodic callback via Tornado
        self.periodic_callback = PeriodicCallback(self.update, 1000)

    def get_connections(self):
        self.connections = 0
        # Get all connections using psutil
        conn = psutil.net_connections('inet')

        if ws.config.CONFIG['PORT'][0] == 'all':
            # If we need the count for all ports we've got it.
            for connection in conn:
                self.connections += 1
        else:
            # Isolate data for the requested ports.
            for port in ws.config.CONFIG['PORT']:
                for connection in conn:
                    if connection.laddr[1] == int(port):
                        self.connections += 1

        return self.connections

    def update(self):
        # Save the old number of connections
        old = self.connections
        self.get_connections()

        # Check if the number of connections has changed
        if old != self.connections:
            # Send the new data.
            if self.connected:
                logger.debug(json.dumps({ "connections": self.get_connections() }))
                self.write_message(json.dumps({ "connections": self.get_connections() }))

    def open(self):
        logger.debug(json.dumps({ "connections": self.get_connections() }))
        self.write_message(json.dumps({ "connections": self.get_connections() }))
        # We have a WebSocket connection
        self.connected = True
        self.periodic_callback.start()

    def on_message(self, message):
        logger.debug(json.dumps({ "connections": self.get_connections() }))
        self.write_message(json.dumps({ "connections": self.get_connections() }))

    def on_close(self):
        logger.debug("Connection closed")
        # We no longer have a WebSocket connection.
        self.connected = False
        self.periodic_callback.stop()
Example #6
class LoLAPI(object):
    
    def __init__(self, client):
        self.timer = PeriodicCallback(self.status, 1000, IOLoop.instance())
        self.client = client

        self.timer.start()


    def status(self):
        self.client.one.update_status(dict(
            last_updated = datetime.now().strftime("%H:%M:%S %d-%m-%y"),
            game_stats = db.games_data.count(),
            players    = db.users.count(),
            full_games = db.games.count(),
            invalid_games = db.invalid_games.count()
        ))

    def set_user(self, name):
        self.user = User.by_name(name)
        stats = GameStats.find(dict(summoner = self.user.get_dbref()))
        games = [Game.find_one(stat['game_id']) for stat in stats]

        self.client.one.update_games([1, 2, 3, 4, 5, 6, 7]) 
#        self.client.one.update_games(list(stats)) 

    def detach(self):
        self.timer.stop()
Example #7
class WebSocketChatHandler(tornado.websocket.WebSocketHandler):

    def initialize(self):
        self.clients = []
        self.callback = PeriodicCallback(self.update_chat, 500)
        self.web_gui_user = self.player_manager.get_by_name(self.get_secure_cookie("player"))

    def open(self, *args):
        self.clients.append(self)
        for msg in self.messages_log:
            self.write_message(msg)
        self.callback.start()

    def on_message(self, message):
        messagejson = json.loads(message)

        self.messages.append(message)
        self.messages_log.append(message)
        self.factory.broadcast("^yellow;<{d}> <^red;{u}^yellow;> {m}".format(
            d=datetime.now().strftime("%H:%M"), u=self.web_gui_user.name, m=messagejson["message"]), 0, "")

    def update_chat(self):
        if len(self.messages) > 0:
            for message in sorted(self.messages):
                for client in self.clients:
                    client.write_message(message)
            del self.messages[0:len(self.messages)]

    def on_close(self):
        self.clients.remove(self)
        self.callback.stop()
Example #8
class cpustatus(tornado.websocket.WebSocketHandler):
	#on_message -> receive data
	#write_message -> send data
	
	#index.html
	def open(self):
		#self.i = readData()
		self.i = 0
		self.last = 0
		self.cpu = PeriodicCallback(self._send_cpu, 500) #
		self.cpu.start()

	def on_message(self, message):
		global MainMotorMax
		self.i = int(message)
		MainMotorMax = self.i
		print message

	def _send_cpu(self):
		#self.write_message(str(vmstat()[15]))
		#self.write_message(str(time.time()))
		#self.i = readData()
		if self.i != self.last:
			self.write_message(str(self.i))
			self.last = self.i
			print self.i
	#
	def on_close(self):
		self.cpu.stop()
Example #9
class WSHandler(tornado.websocket.WebSocketHandler):
    def initialize(self):
        self.values = [[], []]

    def check_origin(self, origin):
        return True

    def open(self):
        # Send messages periodically over the socket at a fixed interval
        self.initialize()
        self.callback = PeriodicCallback(self.send_values, timeInterval)
        self.callback.start()

    def send_values(self):
        MAX_POINTS = 30
        # Generates random values to send via websocket
        for val in self.values:
            if len(val) < MAX_POINTS:
                val.append(randint(1, 10))
            else:
                val.pop(0)
                val.append(randint(1, 10))
            # self.values1 = [randint(1,10) for i in range(100)]
        message = {"Channel0": self.values[0], "Channel1": self.values[1]}
        # self.write_message(message)
        message = {"DataInfo": [{"id": 40, "sname": "SOG"}]}
        self.write_message(message)

    def on_message(self, message):
        pass

    def on_close(self):
        self.callback.stop()
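
On the client side, a stream like this can be consumed with Tornado's own WebSocket client. A minimal sketch, assuming the handler is mounted at ws://localhost:8888/ws (URL and port are assumptions):

from tornado import ioloop, websocket

async def consume():
    conn = await websocket.websocket_connect("ws://localhost:8888/ws")
    while True:
        msg = await conn.read_message()
        if msg is None:  # the server closed the connection
            break
        print(msg)

ioloop.IOLoop.current().run_sync(consume)
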
Example #10
class EventedStatsCollector(StatsCollector):
    """
    Stats Collector which allows to subscribe to value changes.
    Update notifications are throttled: interval between updates is no shorter
    than ``accumulate_time``.

    It is assumed that stat keys are never deleted.
    """
    accumulate_time = 0.1  # value is in seconds

    def __init__(self, crawler):
        super(EventedStatsCollector, self).__init__(crawler)
        self.signals = SignalManager(self)
        self._changes = {}
        self._task = PeriodicCallback(self.emit_changes, self.accumulate_time*1000)
        self._task.start()

        # FIXME: this is ugly
        self.crawler = crawler  # used by ArachnadoCrawlerProcess

    def emit_changes(self):
        if self._changes:
            changes, self._changes = self._changes, {}
            self.signals.send_catch_log(stats_changed, changes=changes)

    def open_spider(self, spider):
        super(EventedStatsCollector, self).open_spider(spider)
        self._task.start()

    def close_spider(self, spider, reason):
        super(EventedStatsCollector, self).close_spider(spider, reason)
        self._task.stop()
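
The throttling idea described in the docstring, buffering changes and flushing them at most once per accumulate_time, can be reduced to a plain PeriodicCallback independent of the signal machinery. A stripped-down sketch (the emit target is just a callable here):

from tornado.ioloop import PeriodicCallback

class ThrottledChanges:
    """Buffer key/value updates and emit them at most every `accumulate_time` seconds."""
    accumulate_time = 0.1  # seconds

    def __init__(self, emit=print):
        self._emit = emit
        self._changes = {}
        self._task = PeriodicCallback(self._flush, self.accumulate_time * 1000)
        self._task.start()

    def set(self, key, value):
        self._changes[key] = value  # only the latest value per key is kept

    def _flush(self):
        if self._changes:
            changes, self._changes = self._changes, {}
            self._emit(changes)
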
Example #11
class WebSocket(tornado.websocket.WebSocketHandler):
    waiters = set()  # multiple clients may be connected at once
    wdata = ""

    def open(self):
        print("open websocket connection")
        WebSocket.waiters.add(self)  # register this client
        self.callback = PeriodicCallback(self._send_message, 30000)  # keep-alive to prevent idle timeouts
        self.callback.start()

    def on_close(self):
        WebSocket.waiters.remove(self)  # unregister this client
        self.callback.stop()
        print("close websocket connection")

    def on_message(self, message):
        WebSocket.wdata = message
        WebSocket.send_updates(message)

    @classmethod
    def send_updates(cls, message):  # class-level broadcast to every connected client
        print(message + ":connection=" + str(len(cls.waiters)))
        for waiter in cls.waiters:
            try:
                waiter.write_message(message)
            except Exception:
                logging.error("Error sending message", exc_info=True)

    # Keep-alive callback to prevent timeouts, fired every 30 seconds
    def _send_message(self):
        self.write_message("C:POLLING")
Example #12
class WebSocketGame(WebSocketHandler):
    def open(self):
        self.game_data = {}
        self.initialize_game()
        self.write_message(self.game_data)

    def on_message(self, message):
        message = json.loads(message)
        if message["type"] == "login":
            self.game_name = message["name"]
            self.game_id = message["game_id"]
            self.loop_callback = PeriodicCallback(self.do_loop, 5000)
        else:
            self.handle_message(message)

    def on_close(self):
        self.loop_callback.stop()
        pass

    def update_status(self, status):
        if status not in ("S", "I", "U", "F"):  # Start, InProgress, Successful, Fail
            return  # Let's try not to hit the status API with bad values.

        url = "http://localhost:8080/private_api/gametask/{}/{}/{}".format(self.game_name, self.game_id, status)
        request = HTTPRequest(url=url)
        http = AsyncHTTPClient()
        http.fetch(request, self.callback)

    def callback(self, response):
        # Catch any errors.
        print "Callback fired."
        print "HTTP Code: {}".format(response.code)
Example #13
class SendWebSocketHandler(tornado.websocket.WebSocketHandler):
    # on_message    receive data
    # write_message send data
    def open(self):
        self.callback = PeriodicCallback(self._send_message, 10000)
        self.callback.start()
        print("[START] WebSocket")

    def on_message(self, message):
        print("[START] WebSocket on_message")
        print(message)

    def _send_message(self):
        cur = DB.execute("SELECT * FROM lm35dz ORDER BY YMDHHMM DESC")
        rec = cur.fetchone()

        send_value = ""
        if rec is None:
            send_value = "Data Nothing"
        else:
            send_value = "%s %s" % (rec[0], rec[1])

        self.write_message(send_value)

    def on_close(self):
        self.callback.stop()
        print("[ENDED] WebSocket")
Example #14
class WebSocket(tornado.websocket.WebSocketHandler):

    def check_origin(self, origin):
        return True

    def on_message(self, message):
        """Evaluates the function pointed to by json-rpc."""

        # Start an infinite loop when this is called
        if message == "read_camera":
            self.camera_loop = PeriodicCallback(self.loop, 10)
            self.camera_loop.start()

        # Extensibility for other methods
        else:
            print("Unsupported function: " + message)

    def loop(self):
        """Sends camera images in an infinite loop."""
        bio = io.BytesIO()

        if args.use_usb:
            _, frame = camera.read()
            img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            img.save(bio, "JPEG")
        else:
            camera.capture(bio, "jpeg", use_video_port=True)

        try:
            self.write_message(base64.b64encode(bio.getvalue()))
        except tornado.websocket.WebSocketClosedError:
            self.camera_loop.stop()
Example #15
def broadcast_players(g_id):
    global pcb2
    global cur_g_id
    data_dict = {}

    if pcb2 is None:
        cur_g_id = g_id
        pcb2 = PeriodicCallback(lambda: broadcast_players(g_id), 4000)
        pcb2.start()
    elif cur_g_id != g_id:
        cur_g_id = g_id
        pcb2.stop()
        pcb2 = PeriodicCallback(lambda: broadcast_players(g_id), 4000)
        pcb2.start()

    g_list = ww_redis_db.keys(g_id+"*")
    for v in g_list:
        v = v.decode("utf-8")

        if len(g_list) > 0:
            # find game with least spaces
            for g_key in g_list:
                game = ww_redis_db.hgetall(g_key)
                game = {k.decode('utf8'): v.decode('utf8') for k, v in game.items()}	#convert from byte array to string dict

                players = game['players'].split("|")		# obtain players in current game in the form of uuid

                data_dict[g_key.decode("utf-8")] = str(len(players))

    data_dict["channel"] = "player_list"

    publish_data("player_list:"+g_id, data_dict)

    return data_dict
Example #16
class SubscriberConnection(ConnectionMixin, SockJSConnection):
    pub_sub = None
    user = None

    def __init__(self, session):
        super(SubscriberConnection, self).__init__(session)
        self.session_store = session_store(self)
        self.pub_sub = get_subscription_provider()

    def _close_invalid_origin(self):
        self.close(4000, message='Invalid origin')

    def on_open(self, request):
        if not set_origin_connection(request, self):
            self._close_invalid_origin()
            return

        super(SubscriberConnection, self).on_open(request)
        if is_heartbeat_enabled():
            self.periodic_callback = PeriodicCallback(self.send_heartbeat, get_heartbeat_frequency())
            self.periodic_callback.start()

    def send_heartbeat(self):
        self.send({'heartbeat': '1'})

    def on_heartbeat(self):
        self.session_store.refresh_all_keys()

    def on_close(self):
        if hasattr(self, 'periodic_callback'):
            self.periodic_callback.stop()
        self.pub_sub.close(self)

    def on_message(self, data):
        if not test_origin(self):
            self._close_invalid_origin()

        try:
            data = self.to_json(data)
            if data == {'heartbeat': '1'}:
                self.on_heartbeat()
                return
            middleware_classes = getattr(settings, 'SWAMP_DRAGON_MIDDLEWARE_CLASSES', None)
            if middleware_classes:
                for middleware_class in middleware_classes:
                    middleware_path, middleware_name = middleware_class.rsplit('.', 1)
                    middleware = getattr(import_module(middleware_path), middleware_name)
                    middleware().process_request(self, data)
            handler = route_handler.get_route_handler(data['route'])
            handler(self).handle(data)
        except Exception as e:
            self.abort_connection()
            raise e

    def abort_connection(self):
        self.close(code=3001, message='Connection aborted')

    def close(self, code=3000, message='Connection closed'):
        self.session.close(code, message)
Example #17
class ProgressWidget(Progress):
    """ ProgressBar that uses an IPython ProgressBar widget for the notebook

    See Also
    --------
    progress: User function
    Progress: Super class with most of the logic
    TextProgressBar: Text version suitable for the console
    """
    def __init__(self, keys, scheduler=None, minimum=0, dt=0.1, complete=False):
        keys = {k.key if hasattr(k, 'key') else k for k in keys}
        self.setup_pre(keys, scheduler, minimum, dt, complete)

        def clear_errors(errors):
            for k in errors:
                self.task_erred(None, k, None, True)

        if self.scheduler.loop._thread_ident == threading.current_thread().ident:
            errors = self.setup(keys, complete)
        else:
            errors = sync(self.scheduler.loop, self.setup, keys, complete)

        self.pc = PeriodicCallback(self._update, 1000 * self._dt)

        from ipywidgets import FloatProgress
        self.bar = FloatProgress(min=0, max=1, description='0.0s')
        self.widget = self.bar

        clear_errors(errors)

        self.pc.start()

    def setup(self, keys, complete):
        errors = Progress.setup(self, keys, complete)
        return errors

    def _ipython_display_(self, **kwargs):
        return self.widget._ipython_display_(**kwargs)

    def _start(self):
        return self._update()

    def stop(self, exception=None, key=None):
        Progress.stop(self, exception, key=None)
        with ignoring(AttributeError):
            self.pc.stop()
        self._update()
        if exception:
            self.bar.bar_style = 'danger'
            self.bar.value = 1.0
        elif not self.keys:
            self.bar.bar_style = 'success'

    def _update(self):
        ntasks = len(self.all_keys)
        ndone = ntasks - len(self.keys)
        self.bar.value = ndone / ntasks if ntasks else 1.0
        self.bar.description = format_time(self.elapsed)
Example #18
class BrowserSession(object):

    def __init__(self, driver_name, ioloop, timeout=5):
        self._driver = _DRIVERS[driver_name]()
        self._ioloop = ioloop
        self._out_queue = Queue()
        self._in_queue = Queue()
        self._timeout = timeout

        keyword_arguments = {
            "in_queue": self._in_queue,
            "out_queue": self._out_queue
        }
        self._thread = threading.Thread(
            target=self._on_start, kwargs=keyword_arguments)
        self._thread.daemon = True

    def start(self):
        self._driver.set_page_load_timeout(self._timeout)
        self._thread.start()
        return self._driver

    def stop(self):
        self._in_queue.put("stop")
        self._thread.join()
        self._driver.delete_all_cookies()
        if not self._out_queue.empty():
            message = self._out_queue.get()
            raise IOLoopException("Error from IOLoop thread: %s" % (message))

    def _on_start(self, in_queue, out_queue):
        self._start_time = time.time()
        self._periodic_callback = PeriodicCallback(
            self._on_cycle, 100, io_loop=self._ioloop)
        self._periodic_callback.start()
        self._ioloop.start()

    def _on_cycle(self):
        if not self._in_queue.empty():
            message = self._in_queue.get()
            self._on_stop()
            if message != "stop":
                self._out_queue.put("unknown message %s" % (message))

        if time.time() - self._start_time > self._timeout:
            self._ioloop.stop()
            self._out_queue.put("timeout")

    def _on_stop(self):
        self._periodic_callback.stop()
        self._ioloop.stop()

    def __enter__(self):
        return self.start()

    def __exit__(self, *args, **kwargs):
        self.stop()
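
Because the class implements __enter__/__exit__, it is meant to be driven as a context manager. A hypothetical usage sketch, assuming the Tornado 4 era API this class was written against (note the io_loop argument to PeriodicCallback), that "firefox" is a key of the module-level _DRIVERS dict, and that the driver is a Selenium-style WebDriver:

from tornado.ioloop import IOLoop

loop = IOLoop()
with BrowserSession("firefox", loop, timeout=5) as driver:
    driver.get("https://example.org")  # Selenium-style call; an assumption
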
Example #19
class WSHandler(tornado.websocket.WebSocketHandler):

    # Websocket connection has opened with client.
    # Run loop every 5 milliseconds
    def open(self):
        self.callback = PeriodicCallback(self.loop, 5)
        self.callback.start()

        global hands
        self.lastHands = copy.deepcopy(hands)

        print "Connection opened with client."
        print "Wave to start hand tracking."

    # Websocket connection closed with client
    # Stop loop from running.
    def on_close(self):
        self.callback.stop()
        print "Connection closed with client."

    # Loop to handle input from OpenNI / Kinect
    # and send information to client
    def loop(self):

        global hands
        global tracking
        global q

        # Poll for updates from OpenNI / Kinect
        tracking.context.wait_any_update_all() 

        # If gesture is in the queue, send message to client.
        while not q.empty():
            self.write_message(q.get())

        # Send hand coordinates to client.
        if hands != self.lastHands:
            self.lastHands = copy.deepcopy(hands)
            self.send_xy()


    # Hand movement detected, so send message to client.
    def send_xy(self):

        global hands

        message = hands
        message["type"] = "hands"
        self.write_message(message)

    def on_message(self, message):
        pass
    

    # Allow cross origin requests
    def check_origin(self, origin):
        return True
Example #20
class SendWebSocket(tornado.websocket.WebSocketHandler):
	#on_message -> receive data
	#write_message -> send data

	#index.html
	def open(self):
		'''
		tty = "/dev/ttyACM0"
		if os.path.exists(tty):
			self.s = serial.Serial(tty,115200)
			#self.s.open()
		enableMotor(self.s, 20)
		setMotor(self.s, 20, 1, 0 ,0)
		'''
		'''
		self.s = mraa.I2c(1)
		self.s.address(0x20)
		self.s.writeReg(0x03,1)
		self.s.writeReg(0x04,1)
		self.s.writeReg(0x06,0)
		'''
		self.i = 0
		self.callback = PeriodicCallback(self._send_message, 1000) #
		self.callback.start()
		
		if(os.path.exists("/dev/video0") and os.system("pidof mjpg_streamer") ==256 ):
			#os.system(' /home/root/mjpg_streamer/mjpg_streamer -i "./input_uvc.so -d /dev/video0 -r 320x240 -f 15 -n" -o "./output_http.so -p 9000 -w ./www" > /dev/null 2>&1 &')
			os.system('/home/root/mjpg-streamer/0start.sh')
			
		print "WebSocket opened" 

	#
	def on_message(self, message):
		'''
		if message == "A key : 1":
			setMotorDuty(self.s, 20, 50)
		else:
			setMotorDuty(self.s, 20, 0)
		'''
		'''
		if message == "A key : 1":
			self.s.writeReg(0x06,50)
		else:
			self.s.writeReg(0x06,0)
		'''
		print message
	
	def _send_message(self):
		self.i += 1
		self.write_message(str(self.i))
		
	#
	def on_close(self):
		self.callback.stop()
		print "WebSocket closed"
Example #21
class WebSocketTest(WebSocketHandler):
    def open(self):
        print "Opened websocket."
        self.game_data = {}
        self.loop_callback = PeriodicCallback(self.do_loop, 5000)

        self.game_data["wood"] = 10
        self.game_data["heat"] = 0
        self.game_data["progress"] = 0
        self.write_json(self.game_data)
        self.loop_callback.start()

    def on_message(self, message):
        print "Message: {}".format(message)
        message = json.loads(message)
        if message["type"] == "add_1":
            self.game_data["wood"] += 1
            self.write_json(self.game_data)
        if message["type"] == "add_5":
            self.game_data["wood"] += 5
            self.write_json(self.game_data)

    def do_loop(self):
        if self.game_data["wood"] < 1:
            self.turns_no_wood += 1
            self.game_data["heat"] -= 5 * self.turns_no_wood
            if self.game_data["heat"] < 1:
                self.game_data["heat"] = 0
        else:
            self.turns_no_wood = 0
            self.game_data["heat"] += 5 * self.game_data["wood"]
            self.game_data["wood"] -= self.game_data["heat"] / 50
            if self.game_data["wood"] < 1:
                self.game_data["wood"] = 0

        if self.game_data["heat"] > 2000:
            self.game_data["progress"] += 1

        if self.game_data["heat"] > 2500:
            print "Heat failure."
            self.game_data["fail"] = True

        if self.game_data["progress"] > 9:
            print "Success."
            self.game_data["success"] = True

        self.write_json(self.game_data)

    def on_close(self):
        print "Closed websocket."
        self.loop_callback.stop()

    def write_json(self, message):
        self.write_message(json.dumps(message))
Example #22
class PoolQueue(object):
    def __init__(self, *args, **kwargs):
        super(PoolQueue, self).__init__()

        from django.conf import settings as django_settings
        from signalqueue.worker import backends
        from signalqueue import SQ_RUNMODES as runmodes

        self.active = kwargs.get("active", True)
        self.halt = kwargs.get("halt", False)
        self.interval = 1
        self.queue_name = kwargs.get("queue_name", "default")

        self.runmode = runmodes["SQ_ASYNC_DAEMON"]
        self.queues = backends.ConnectionHandler(django_settings.SQ_QUEUES, self.runmode)
        self.signalqueue = self.queues[self.queue_name]
        self.signalqueue.runmode = self.runmode

        # use interval from the config if it exists
        interval = kwargs.get("interval", self.signalqueue.queue_interval)
        if interval is not None:
            self.interval = interval

        if self.interval > 0:
            if self.halt:
                self.shark = PeriodicCallback(self.cueball_scratch, self.interval * 10)
            else:
                self.shark = PeriodicCallback(self.cueball, self.interval * 10)

        if self.active:
            self.shark.start()

    def stop(self):
        self.active = False
        self.shark.stop()

    def rerack(self):
        self.active = True
        self.shark.start()

    def cueball(self):
        # logg.info("Dequeueing signal...")
        with self.signalqueue.log_exceptions():
            self.signalqueue.dequeue()

    def cueball_scratch(self):
        # logg.info("Dequeueing signal...")
        with self.signalqueue.log_exceptions():
            self.signalqueue.dequeue()

        if self.signalqueue.count() < 1:
            print "Queue exhausted, exiting..."
            raise KeyboardInterrupt
Example #23
class SubscriberConnection(ConnectionMixin, SockJSConnection):
    pub_sub = None

    def __init__(self, session):
        super(SubscriberConnection, self).__init__(session)
        self.session_store = session_store(self)
        self.pub_sub = get_subscription_provider()

    def _close_invalid_origin(self):
        self.close(4000, message='Invalid origin')

    def on_open(self, request):
        if not set_origin_connection(request, self):
            self._close_invalid_origin()
            return

        super(SubscriberConnection, self).on_open(request)
        if is_heartbeat_enabled():
            self.periodic_callback = PeriodicCallback(self.send_heartbeat, get_heartbeat_frequency())
            self.periodic_callback.start()

    def send_heartbeat(self):
        self.send({'heartbeat': '1'})

    def on_heartbeat(self):
        self.session_store.refresh_all_keys()

    def on_close(self):
        if hasattr(self, 'periodic_callback'):
            self.periodic_callback.stop()
        self.pub_sub.close(self)

    def on_message(self, data):
        if not test_origin(self):
            self._close_invalid_origin()

        try:
            data = self.to_json(data)
            if data == {'heartbeat': '1'}:
                self.on_heartbeat()
                return
            handler = route_handler.get_route_handler(data['route'])
            handler(self).handle(data)
        except Exception as e:
            self.abort_connection()
            raise e

    def abort_connection(self):
        self.close(code=3001, message='Connection aborted')

    def close(self, code=3000, message='Connection closed'):
        self.session.close(code, message)
Example #24
class ProcessStatsMonitor(object):
    """ A class which emits process stats periodically """

    signal_updated = object()

    def __init__(self, interval=1.0):
        self.signals = SignalManager(self)
        self.process = psutil.Process(os.getpid())
        self.interval = interval
        self._task = PeriodicCallback(self._emit, self.interval*1000)
        self._recent = {}

    def start(self):
        # yappi.start()
        self._task.start()

    def stop(self):
        self._task.stop()
        # stats = yappi.get_func_stats()
        # stats.sort('tsub', 'desc')
        # with open("func-stats.txt", 'wt') as f:
        #     stats.print_all(f, columns={
        #         0: ("name", 80),
        #         1: ("ncall", 10),
        #         2: ("tsub", 8),
        #         3: ("ttot", 8),
        #         4: ("tavg",8)
        #     })
        #
        # pstats = yappi.convert2pstats(stats)
        # pstats.dump_stats("func-stats.prof")

    def get_recent(self):
        return self._recent

    def _emit(self):
        cpu_times = self.process.cpu_times()
        ram_usage = self.process.memory_info()
        stats = {
            'ram_percent': self.process.memory_percent(),
            'ram_rss': ram_usage.rss,
            'ram_vms': ram_usage.vms,
            'cpu_percent': self.process.cpu_percent(),
            'cpu_time_user': cpu_times.user,
            'cpu_time_system': cpu_times.system,
            'num_fds': self.process.num_fds(),
            'context_switches': self.process.num_ctx_switches(),
            'num_threads': self.process.num_threads(),
            'server_time': int(time.time()*1000),
        }
        self._recent = stats
        self.signals.send_catch_log(self.signal_updated, stats=stats)
Example #25
class WSHandler(tornado.websocket.WebSocketHandler):
    def open(self):
        self.callback = PeriodicCallback(self.send_hello, 2000)
        self.callback.start()

    def send_hello(self):
        self.write_message(json.dumps(ws_push_data))

    def on_message(self, message):
        pass

    def on_close(self):
        self.callback.stop()
Example #26
class WSHandler(tornado.websocket.WebSocketHandler):
    def open(self):
        self.callback = PeriodicCallback(self.send_hello, 10)
        self.callback.start()

    def send_hello(self):
        self.write_message("Websocket")

    def on_message(self, message):
        print "Recieved Websocket Reply {}".format(message)

    def on_close(self):
        self.callback.stop()
Example #27
class AsyncPopen(object):
    '''Asynchronous version of :class:`subprocess.Popen`.'''
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

        self.on_output = Event()
        self.on_end = Event()

    def run(self):
        self.ioloop = IOLoop.instance()
        (master_fd, slave_fd) = pty.openpty()

        # make stdout, stderr non-blocking
        fcntl.fcntl(master_fd, fcntl.F_SETFL,
                    fcntl.fcntl(master_fd, fcntl.F_GETFL) | os.O_NONBLOCK)

        self.master_fd = master_fd
        self.master = os.fdopen(master_fd)

        # listen to stdout, stderr
        self.ioloop.add_handler(master_fd, self._handle_subprocess_stdout,
                                self.ioloop.READ)

        slave = os.fdopen(slave_fd)
        self.kwargs["stdout"] = slave
        self.kwargs["stderr"] = slave
        self.kwargs["close_fds"] = True
        self.pipe = subprocess.Popen(*self.args, **self.kwargs)

        self.stdin = self.pipe.stdin

        # check for process exit
        self.wait_callback = PeriodicCallback(self._wait_for_end, 250)
        self.wait_callback.start()

    def _handle_subprocess_stdout(self, fd, events):
        if not self.master.closed and (events & IOLoop._EPOLLIN) != 0:
            data = self.master.read()
            self.on_output(data)

        self._wait_for_end(events)

    def _wait_for_end(self, events=0):
        self.pipe.poll()
        if self.pipe.returncode is not None or \
                (events & tornado.ioloop.IOLoop._EPOLLHUP) > 0:
            self.wait_callback.stop()
            self.master.close()
            self.ioloop.remove_handler(self.master_fd)
            self.on_end(self.pipe.returncode)
Example #28
class SendWebSocket(tornado.websocket.WebSocketHandler):

  num = 0  # the number that gets incremented
  conn = 0  # number of currently connected hosts

  @classmethod
  def update_num(cls):
    cls.num += 1

  @classmethod
  def increment_conn(cls):
    cls.conn += 1

  @classmethod
  def decrement_conn(cls):
    cls.conn -= 1

  @classmethod
  def get_conn(cls):
    return cls.conn

  def check_origin(self, origin):
    return True

  def open(self):
    self.callback = PeriodicCallback(self._send_message, 400)  # callback that sends messages to the client
    self.callback.start()

    self.increment_callback = PeriodicCallback(self.update_num, 400)  # callback that increments the value
    self.increment_conn()
    if self.get_conn() == 1:
      self.increment_callback.start()  # start incrementing once the first host connects

    print('WebSocket opened')

  def on_message(self, message):
    print(message)

  def _send_message(self):
    self.write_message(str(self.num) + ' connected hosts=' + str(self.conn))

  def on_close(self):
    self.callback.stop()

    self.decrement_conn()
    if self.get_conn() == 0:
      self.increment_callback.stop()  # stop incrementing once no hosts remain connected

    print('WebSocket closed')
Example #29
class TornadoTelemetryManager(TelemetryManager):  # pylint: disable=W0612
    def __init__(self, ioloop):
        TelemetryManager.__init__(self)
        self.ioloop = ioloop
        self._timer = PeriodicCallback(
            stack_context.wrap(self._start_clean_up_timer),
            self.CLEAN_UP_INTERVAL * self.CLEAN_UP_INTERVAL_MULTIPLIER,
            self.ioloop)
        self._timer.start()

    @tornado.gen.coroutine
    def _start_clean_up_timer(self):
        self.clean_up_telemetry_data()

    def _stop_clean_up_timer(self):
        self._timer.stop()
Example #30
class SpitRandomStuff(object):
    def __init__(self, stream, address):
        log.info('Received connection from %s', address)
        self.address = address
        self.stream = stream
        self.stream.set_close_callback(self._on_close)
        self.writer = PeriodicCallback(self._random_stuff, 500)
        self.writer.start()

    def _on_close(self):
        log.info('Closed connection from %s', self.address)
        self.writer.stop()

    def _random_stuff(self):
        output = os.urandom(60)
        self.stream.write(md5(output).hexdigest() + "\n")
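
The constructor takes an IOStream and a peer address, which is exactly what tornado.tcpserver.TCPServer.handle_stream provides. A minimal sketch of serving it (the port is an assumption):

from tornado.ioloop import IOLoop
from tornado.tcpserver import TCPServer

class RandomStuffServer(TCPServer):
    def handle_stream(self, stream, address):
        # One SpitRandomStuff instance per incoming connection
        SpitRandomStuff(stream, address)

if __name__ == "__main__":
    RandomStuffServer().listen(9999)  # port is an assumption
    IOLoop.current().start()
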
Example #31
class ServerNewHandler(WebSocketBaseHandler):
    def on_message(self, message):
        self.params.update(json_loads(message))

        # Validate the request parameters
        try:
            args = ['cluster_id', 'name', 'public_ip', 'username', 'passwd']

            self.guarantee(*args)

            for i in args[1:]:
                self.params[i] = self.params[i].strip()

            validate_ip(self.params['public_ip'])

            self.params.update(self.get_lord())
            self.params.update({'owner': self.current_user['id']})
        except Exception as e:
            self.write_message(str(e))
            self.close()
            return

        IOLoop.current().spawn_callback(
            callback=self.handle_msg)  # on_message itself cannot be a coroutine; async work goes through spawn_callback

    @coroutine
    def handle_msg(self):
        is_deploying = self.redis.hget(DEPLOYING, self.params['public_ip'])
        is_deployed = self.redis.hget(DEPLOYED, self.params['public_ip'])

        # Notify that adding the host failed; failure reasons should later be categorized and reported to the user
        message = {
            'owner':
            self.params.get('owner'),
            'ip':
            self.params['public_ip'],
            'tip':
            '{}'.format(
                self.params.get('lord') if self.params.get('form') ==
                FORM_COMPANY else 0)
        }

        if is_deploying:
            reason = '%s is currently being deployed' % self.params['public_ip']
            self.write_message(reason)
            self.write_message('failure')
            message['reason'] = reason
            yield self.message_service.notify_server_add_failed(message)
            return

        if is_deployed:
            reason = '%s has already been deployed' % self.params['public_ip']
            self.write_message(reason)
            self.write_message('failure')
            message['reason'] = reason
            yield self.message_service.notify_server_add_failed(message)
            return

        # Encrypt the password before saving it to redis
        passwd = self.params['passwd']
        self.params['passwd'] = Aes.encrypt(passwd)

        self.redis.hset(DEPLOYING, self.params['public_ip'],
                        json.dumps(self.params))

        self.period = PeriodicCallback(self.check, 3000)  # periodic check, every 3 seconds
        self.period.start()

        self.params.update({
            'passwd': passwd,
            'cmd': MONITOR_CMD,
            'rt': True,
            'out_func': self.write_message
        })
        _, err = yield self.server_service.remote_ssh(self.params)

        err = [e for e in err
               if not re.search(r'symlink|resolve host', e)]  # ignore certain errors

        # Deployment failed
        if err:
            if err[0] == 'Authentication failed.':
                reason = 'Authentication failed'
                self.write_message(reason)
                message['reason'] = reason
                yield self.message_service.notify_server_add_failed(message)
            self.write_message('failure')
            self.period.stop()
            self.close()

            self.redis.hdel(DEPLOYING, self.params['public_ip'])

    def check(self):
        ''' Check whether the host has reported its information '''
        result = self.redis.hget(DEPLOYED, self.params['public_ip'])

        if result:
            self.write_message('success')
            self.period.stop()
            self.close()

    def on_close(self):
        if hasattr(self, 'period'):
            self.period.stop()
Example #32
class MetadataStorage:
    def __init__(self, server, gc_path, database):
        self.server = server
        database.register_local_namespace(METADATA_NAMESPACE)
        self.mddb = database.wrap_namespace(METADATA_NAMESPACE,
                                            parse_keys=False)
        self.pending_requests = {}
        self.events = {}
        self.busy = False
        self.gc_path = gc_path
        self.prune_cb = PeriodicCallback(self.prune_metadata,
                                         METADATA_PRUNE_TIME)

    def update_gcode_path(self, path):
        if path == self.gc_path:
            return
        self.mddb.clear()
        self.gc_path = path
        if not self.prune_cb.is_running():
            self.prune_cb.start()

    def close(self):
        self.prune_cb.stop()

    def get(self, key, default=None):
        return self.mddb.get(key, default)

    def __getitem__(self, key):
        return self.mddb[key]

    def prune_metadata(self):
        for fname in list(self.mddb.keys()):
            fpath = os.path.join(self.gc_path, fname)
            if not os.path.exists(fpath):
                del self.mddb[fname]
                logging.info(f"Pruned file: {fname}")
                continue

    def _has_valid_data(self, fname, fsize, modified):
        mdata = self.mddb.get(fname, {'size': "", 'modified': 0})
        return mdata['size'] == fsize and mdata['modified'] == modified

    def remove_file(self, fname):
        try:
            del self.mddb[fname]
        except Exception:
            pass

    def parse_metadata(self, fname, fsize, modified, notify=False):
        evt = Event()
        if fname in self.pending_requests or \
                self._has_valid_data(fname, fsize, modified):
            # request already pending or not necessary
            evt.set()
            return evt
        self.pending_requests[fname] = (fsize, modified, notify, evt)
        if self.busy:
            return evt
        self.busy = True
        IOLoop.current().spawn_callback(self._process_metadata_update)
        return evt

    async def _process_metadata_update(self):
        while self.pending_requests:
            fname, (fsize, modified, notify, evt) = \
                self.pending_requests.popitem()
            if self._has_valid_data(fname, fsize, modified):
                evt.set()
                continue
            retries = 3
            while retries:
                try:
                    await self._run_extract_metadata(fname, notify)
                except Exception:
                    logging.exception("Error running extract_metadata.py")
                    retries -= 1
                else:
                    break
            else:
                self.mddb[fname] = {'size': fsize, 'modified': modified}
                logging.info(
                    f"Unable to extract medatadata from file: {fname}")
            evt.set()
        self.busy = False

    async def _run_extract_metadata(self, filename, notify):
        # Escape double quotes in the file name so that it may be
        # quoted correctly in the shell command below
        filename = filename.replace("\"", "\\\"")
        cmd = " ".join([
            sys.executable, METADATA_SCRIPT, "-p", self.gc_path, "-f",
            f"\"{filename}\""
        ])
        shell_command = self.server.lookup_plugin('shell_command')
        scmd = shell_command.build_shell_command(cmd, log_stderr=True)
        result = await scmd.run_with_response(timeout=10.)
        if result is None:
            raise self.server.error(f"Metadata extraction error")
        try:
            decoded_resp = json.loads(result.strip())
        except Exception:
            logging.debug(f"Invalid metadata response:\n{result}")
            raise
        path = decoded_resp['file']
        metadata = decoded_resp['metadata']
        if not metadata:
            # This indicates an error, do not add metadata for this
            raise self.server.error("Unable to extract metadata")
        self.mddb[path] = dict(metadata)
        metadata['filename'] = path
        if notify:
            self.server.send_event("file_manager:metadata_update", metadata)
Example #33
class State(BaseState):

    NAME = "Redis"

    OK_RESPONSE = "OK"

    def __init__(self, *args, **kwargs):
        super(State, self).__init__(*args, **kwargs)
        self.host = None
        self.port = None
        self.db = None
        self.client = None
        self.connection_check = None

    def initialize(self):
        settings = self.config.get('settings', {})
        self.host = settings.get("host", "localhost")
        self.port = settings.get("port", 6379)
        self.db = settings.get("db", 0)
        self.client = toredis.Client(io_loop=self.io_loop)
        self.client.state = self
        self.connection_check = PeriodicCallback(self.check_connection, 1000)
        self.connect()
        logger.info("Redis State initialized")

    def on_select(self, res):
        if res != self.OK_RESPONSE:
            logger.error("state select database: {0}".format(res))

    def connect(self):
        """
        Connect to Redis.
        Do not even try to connect if State is faked.
        """
        if self.fake:
            return

        try:
            self.client.connect(host=self.host, port=self.port)
        except Exception as e:
            logger.error("error connecting to Redis server: %s" % (str(e)))
        else:
            if self.db and isinstance(self.db, int):
                self.client.select(self.db, callback=self.on_select)

        self.connection_check.stop()
        self.connection_check.start()

    def check_connection(self):
        if not self.client.is_connected():
            logger.info('reconnecting to Redis')
            self.connect()

    @staticmethod
    def get_presence_set_key(project_id, namespace, channel):
        return "centrifuge:presence:set:%s:%s:%s" % (project_id, namespace,
                                                     channel)

    @coroutine
    def add_presence(self,
                     project_id,
                     namespace,
                     channel,
                     uid,
                     user_info,
                     presence_timeout=None):
        """
        Add user's presence with appropriate expiration time.
        Must be called when user subscribes on channel.
        """
        if self.fake:
            raise Return((True, None))
        now = int(time.time())
        expire_at = now + (presence_timeout or self.presence_timeout)
        hash_key = self.get_presence_hash_key(project_id, namespace, channel)
        set_key = self.get_presence_set_key(project_id, namespace, channel)
        try:
            yield Task(self.client.zadd, set_key, {uid: expire_at})
            yield Task(self.client.hset, hash_key, uid, json_encode(user_info))
        except StreamClosedError as e:
            raise Return((None, e))
        else:
            raise Return((True, None))

    @coroutine
    def remove_presence(self, project_id, namespace, channel, uid):
        """
        Remove user's presence from Redis.
        Must be called on disconnects of any kind.
        """
        if self.fake:
            raise Return((True, None))
        hash_key = self.get_presence_hash_key(project_id, namespace, channel)
        set_key = self.get_presence_set_key(project_id, namespace, channel)
        try:
            yield Task(self.client.hdel, hash_key, uid)
            yield Task(self.client.zrem, set_key, uid)
        except StreamClosedError as e:
            raise Return((None, e))
        else:
            raise Return((True, None))

    @coroutine
    def get_presence(self, project_id, namespace, channel):
        """
        Get presence for channel.
        """
        if self.fake:
            raise Return((None, None))
        now = int(time.time())
        hash_key = self.get_presence_hash_key(project_id, namespace, channel)
        set_key = self.get_presence_set_key(project_id, namespace, channel)
        try:
            expired_keys = yield Task(self.client.zrangebyscore, set_key, 0,
                                      now)
            if expired_keys:
                yield Task(self.client.zremrangebyscore, set_key, 0, now)
                yield Task(self.client.hdel, hash_key,
                           [x.decode() for x in expired_keys])
            data = yield Task(self.client.hgetall, hash_key)
        except StreamClosedError:
            raise Return((None, 'presence unavailable'))
        else:
            raise Return((dict_from_list(data), None))

    @coroutine
    def add_history_message(self,
                            project_id,
                            namespace,
                            channel,
                            message,
                            history_size=None):
        """
        Add message to channel's history.
        Must be called when new message has been published.
        """
        if self.fake:
            raise Return((True, None))
        history_size = history_size or self.history_size
        list_key = self.get_history_list_key(project_id, namespace, channel)
        try:
            yield Task(self.client.lpush, list_key, json_encode(message))
            yield Task(self.client.ltrim, list_key, 0, history_size - 1)
        except StreamClosedError as e:
            raise Return((None, e))
        else:
            raise Return((True, None))

    @coroutine
    def get_history(self, project_id, namespace, channel):
        """
        Get a list of last messages for channel.
        """
        if self.fake:
            raise Return((None, None))
        history_list_key = self.get_history_list_key(project_id, namespace,
                                                     channel)
        try:
            data = yield Task(self.client.lrange, history_list_key, 0, -1)
        except StreamClosedError:
            raise Return((None, self.application.INTERNAL_SERVER_ERROR))
        else:
            raise Return(([json_decode(x.decode()) for x in data], None))
Example #34
class Authorization:
    def __init__(self, config):
        self.server = config.get_server()
        self.login_timeout = config.getint('login_timeout', 90)
        database = self.server.lookup_component('database')
        database.register_local_namespace('authorized_users', forbidden=True)
        self.users = database.wrap_namespace('authorized_users')
        api_user = self.users.get(API_USER, None)
        if api_user is None:
            self.api_key = uuid.uuid4().hex
            self.users[API_USER] = {
                'username': API_USER,
                'api_key': self.api_key,
                'created_on': time.time()
            }
        else:
            self.api_key = api_user['api_key']
        self.trusted_users = {}
        self.oneshot_tokens = {}
        self.permitted_paths = set()

        # Get allowed cors domains
        self.cors_domains = []
        cors_cfg = config.get('cors_domains', "").strip()
        cds = [d.strip() for d in cors_cfg.split('\n') if d.strip()]
        for domain in cds:
            bad_match = re.search(r"^.+\.[^:]*\*", domain)
            if bad_match is not None:
                raise config.error(
                    f"Unsafe CORS Domain '{domain}'.  Wildcards are not"
                    " permitted in the top level domain.")
            self.cors_domains.append(
                domain.replace(".", "\\.").replace("*", ".*"))

        # Get Trusted Clients
        self.trusted_ips = []
        self.trusted_ranges = []
        self.trusted_domains = []
        trusted_clients = config.get('trusted_clients', "")
        trusted_clients = [
            c.strip() for c in trusted_clients.split('\n') if c.strip()
        ]
        for val in trusted_clients:
            # Check IP address
            try:
                tc = ipaddress.ip_address(val)
            except ValueError:
                pass
            else:
                self.trusted_ips.append(tc)
                continue
            # Check ip network
            try:
                tc = ipaddress.ip_network(val)
            except ValueError:
                pass
            else:
                self.trusted_ranges.append(tc)
                continue
            # Check hostname
            self.trusted_domains.append(val.lower())

        t_clients = "\n".join([str(ip) for ip in self.trusted_ips] +
                              [str(rng) for rng in self.trusted_ranges] +
                              self.trusted_domains)
        c_domains = "\n".join(self.cors_domains)

        logging.info(f"Authorization Configuration Loaded\n"
                     f"Trusted Clients:\n{t_clients}\n"
                     f"CORS Domains:\n{c_domains}")

        self.prune_handler = PeriodicCallback(self._prune_conn_handler,
                                              PRUNE_CHECK_TIME)
        self.prune_handler.start()

        # Register Authorization Endpoints
        self.permitted_paths.add("/access/login")
        self.permitted_paths.add("/access/refresh_jwt")
        self.server.register_endpoint("/access/login", ['POST'],
                                      self._handle_login)
        self.server.register_endpoint("/access/logout", ['POST'],
                                      self._handle_logout)
        self.server.register_endpoint("/access/refresh_jwt", ['POST'],
                                      self._handle_refresh_jwt)
        self.server.register_endpoint("/access/user",
                                      ['GET', 'POST', 'DELETE'],
                                      self._handle_user_request)
        self.server.register_endpoint("/access/user/password", ['POST'],
                                      self._handle_password_reset)
        self.server.register_endpoint("/access/api_key", ['GET', 'POST'],
                                      self._handle_apikey_request,
                                      protocol=['http'])
        self.server.register_endpoint("/access/oneshot_token", ['GET'],
                                      self._handle_token_request,
                                      protocol=['http'])

    async def _handle_apikey_request(self, web_request):
        action = web_request.get_action()
        if action.upper() == 'POST':
            self.api_key = uuid.uuid4().hex
            self.users[f'{API_USER}.api_key'] = self.api_key
        return self.api_key

    async def _handle_token_request(self, web_request):
        ip = web_request.get_ip_address()
        user_info = web_request.get_current_user()
        return self.get_oneshot_token(ip, user_info)

    async def _handle_login(self, web_request):
        return self._login_jwt_user(web_request)

    async def _handle_logout(self, web_request):
        user_info = web_request.get_current_user()
        if user_info is None:
            raise self.server.error("No user logged in")
        username = user_info['username']
        if username in RESERVED_USERS:
            raise self.server.error(
                f"Invalid log out request for user {username}")
        self.users.pop(f"{username}.jwt_secret", None)
        return {"username": username, "action": "user_logged_out"}

    async def _handle_refresh_jwt(self, web_request):
        refresh_token = web_request.get_str('refresh_token')
        user_info = self._decode_jwt(refresh_token, token_type="refresh")
        username = user_info['username']
        secret = bytes.fromhex(user_info['jwt_secret'])
        token = self._generate_jwt(username, secret)
        return {
            'username': username,
            'token': token,
            'action': 'user_jwt_refresh'
        }

    async def _handle_user_request(self, web_request):
        action = web_request.get_action()
        if action == "GET":
            user = web_request.get_current_user()
            if user is None:
                return {
                    'username': None,
                    'created_on': None,
                }
            else:
                return {
                    'username': user['username'],
                    'created_on': user.get('created_on')
                }
        elif action == "POST":
            # Create User
            return self._login_jwt_user(web_request, create=True)
        elif action == "DELETE":
            # Delete User
            return self._delete_jwt_user(web_request)

    async def _handle_password_reset(self, web_request):
        password = web_request.get_str('password')
        new_pass = web_request.get_str('new_password')
        user_info = web_request.get_current_user()
        if user_info is None:
            raise self.server.error("No Current User")
        username = user_info['username']
        if username in RESERVED_USERS:
            raise self.server.error(
                f"Invalid Reset Request for user {username}")
        salt = bytes.fromhex(user_info['salt'])
        hashed_pass = hashlib.pbkdf2_hmac('sha256', password.encode(), salt,
                                          HASH_ITER).hex()
        if hashed_pass != user_info['password']:
            raise self.server.error("Invalid Password")
        new_hashed_pass = hashlib.pbkdf2_hmac('sha256', new_pass.encode(),
                                              salt, HASH_ITER).hex()
        self.users[f'{username}.password'] = new_hashed_pass
        return {'username': username, 'action': "user_password_reset"}

    def _login_jwt_user(self, web_request, create=False):
        username = web_request.get_str('username')
        password = web_request.get_str('password')
        if username in RESERVED_USERS:
            raise self.server.error(f"Invalid Request for user {username}")
        if create:
            if username in self.users:
                raise self.server.error(f"User {username} already exists")
            salt = secrets.token_bytes(32)
            hashed_pass = hashlib.pbkdf2_hmac('sha256', password.encode(),
                                              salt, HASH_ITER).hex()
            user_info = {
                'username': username,
                'password': hashed_pass,
                'salt': salt.hex(),
                'created_on': time.time()
            }
            self.users[username] = user_info
            action = "user_created"
        else:
            if username not in self.users:
                raise self.server.error(f"Unregistered User: {username}")
            user_info = self.users[username]
            salt = bytes.fromhex(user_info['salt'])
            hashed_pass = hashlib.pbkdf2_hmac('sha256', password.encode(),
                                              salt, HASH_ITER).hex()
            action = "user_logged_in"
        if hashed_pass != user_info['password']:
            raise self.server.error("Invalid Password")
        jwt_secret = user_info.get('jwt_secret', None)
        if jwt_secret is None:
            jwt_secret = secrets.token_bytes(32)
            user_info['jwt_secret'] = jwt_secret.hex()
            self.users[username] = user_info
        else:
            jwt_secret = bytes.fromhex(jwt_secret)
        token = self._generate_jwt(username, jwt_secret)
        refresh_token = self._generate_jwt(
            username,
            jwt_secret,
            token_type="refresh",
            exp_time=datetime.timedelta(days=self.login_timeout))
        return {
            'username': username,
            'token': token,
            'refresh_token': refresh_token,
            'action': action
        }

    def _delete_jwt_user(self, web_request):
        password = web_request.get_str('password')
        user_info = web_request.get_current_user()
        if user_info is None:
            raise self.server.error("No Current User")
        username = user_info['username']
        if username in RESERVED_USERS:
            raise self.server.error(f"Invalid request for user {username}")
        salt = bytes.fromhex(user_info['salt'])
        hashed_pass = hashlib.pbkdf2_hmac('sha256', password.encode(), salt,
                                          HASH_ITER).hex()
        if hashed_pass != user_info['password']:
            raise self.server.error("Invalid Password")
        del self.users[username]
        return {"username": username, "action": "user_deleted"}

    def _generate_jwt(self,
                      username,
                      secret,
                      token_type="auth",
                      exp_time=JWT_EXP_TIME):
        curtime = time.time()
        payload = {
            'iss': "Moonraker",
            'iat': curtime,
            'exp': curtime + exp_time.total_seconds(),
            'username': username,
            'token_type': token_type
        }
        enc_header = base64url_encode(json.dumps(JWT_HEADER).encode())
        enc_payload = base64url_encode(json.dumps(payload).encode())
        message = enc_header + b"." + enc_payload
        signature = base64url_encode(hmac.digest(secret, message, "sha256"))
        message += b"." + signature
        return message.decode()

    def _decode_jwt(self, jwt, token_type="auth"):
        parts = jwt.encode().split(b".")
        if len(parts) != 3:
            raise self.server.error(f"Invalid JWT length of {len(parts)}")
        header = json.loads(base64url_decode(parts[0]))
        payload = json.loads(base64url_decode(parts[1]))
        if header != JWT_HEADER:
            raise self.server.error("Invalid JWT header")
        recd_type = payload.get('token_type', "")
        if token_type != recd_type:
            raise self.server.error(
                f"JWT Token type mismatch: Expected {token_type}, "
                f"Recd: {recd_type}", 401)
        if time.time() > payload['exp']:
            raise self.server.error("JWT expired", 401)
        username = payload.get('username')
        user_info = self.users.get(username, None)
        if user_info is None:
            raise self.server.error(
                f"Invalid JWT, no registered user {username}", 401)
        jwt_secret = user_info.get('jwt_secret', None)
        if jwt_secret is None:
            raise self.server.error(
                f"Invalid JWT, user {username} not logged in", 401)
        secret = bytes.fromhex(jwt_secret)
        # Decode and verify signature
        signature = base64url_decode(parts[2])
        calc_sig = hmac.digest(secret, parts[0] + b"." + parts[1], "sha256")
        if signature != calc_sig:
            raise self.server.error("Invalid JWT signature")
        return user_info

    def _prune_conn_handler(self):
        cur_time = time.time()
        for ip, user_info in list(self.trusted_users.items()):
            exp_time = user_info['expires_at']
            if cur_time >= exp_time:
                self.trusted_users.pop(ip, None)
                logging.info(f"Trusted Connection Expired, IP: {ip}")

    def _oneshot_token_expire_handler(self, token):
        self.oneshot_tokens.pop(token, None)

    def get_oneshot_token(self, ip_addr, user):
        token = base64.b32encode(os.urandom(20)).decode()
        ioloop = IOLoop.current()
        hdl = ioloop.call_later(ONESHOT_TIMEOUT,
                                self._oneshot_token_expire_handler, token)
        self.oneshot_tokens[token] = (ip_addr, user, hdl)
        return token

    def _check_json_web_token(self, request):
        auth_token = request.headers.get("Authorization")
        if auth_token is None:
            auth_token = request.headers.get("X-Access-Token")
        if auth_token and auth_token.startswith("Bearer "):
            auth_token = auth_token[7:]
            try:
                return self._decode_jwt(auth_token)
            except Exception as e:
                raise HTTPError(401, str(e))
        return None

    def _check_authorized_ip(self, ip):
        if ip in self.trusted_ips:
            return True
        for rng in self.trusted_ranges:
            if ip in rng:
                return True
        fqdn = socket.getfqdn(str(ip)).lower()
        if fqdn in self.trusted_domains:
            return True
        return False

    def _check_trusted_connection(self, ip):
        if ip is not None:
            curtime = time.time()
            exp_time = curtime + TRUSTED_CONNECTION_TIMEOUT
            if ip in self.trusted_users:
                self.trusted_users[ip]['expires_at'] = exp_time
                return self.trusted_users[ip]
            elif self._check_authorized_ip(ip):
                logging.info(f"Trusted Connection Detected, IP: {ip}")
                self.trusted_users[ip] = {
                    'username': TRUSTED_USER,
                    'password': None,
                    'created_on': curtime,
                    'expires_at': exp_time
                }
                return self.trusted_users[ip]
        return None

    def _check_oneshot_token(self, token, cur_ip):
        if token in self.oneshot_tokens:
            ip_addr, user, hdl = self.oneshot_tokens.pop(token)
            IOLoop.current().remove_timeout(hdl)
            if cur_ip != ip_addr:
                logging.info(f"Oneshot Token IP Mismatch: expected{ip_addr}"
                             f", Recd: {cur_ip}")
                return None
            return user
        else:
            return None

    def check_authorized(self, request):
        if request.path in self.permitted_paths:
            return None

        # Check JSON Web Token
        jwt_user = self._check_json_web_token(request)
        if jwt_user is not None:
            return jwt_user

        try:
            ip = ipaddress.ip_address(request.remote_ip)
        except ValueError:
            logging.exception(
                f"Unable to Create IP Address {request.remote_ip}")
            ip = None

        # Check oneshot access token
        ost = request.arguments.get('token', None)
        if ost is not None:
            ost_user = self._check_oneshot_token(ost[-1].decode(), ip)
            if ost_user is not None:
                return ost_user

        # Check API Key Header
        key = request.headers.get("X-Api-Key")
        if key and key == self.api_key:
            return self.users[API_USER]

        # Check if IP is trusted
        trusted_user = self._check_trusted_connection(ip)
        if trusted_user is not None:
            return trusted_user

        raise HTTPError(401, "Unauthorized")

    def check_cors(self, origin, request=None):
        if origin is None or not self.cors_domains:
            return False
        for regex in self.cors_domains:
            match = re.match(regex, origin)
            if match is not None:
                if match.group() == origin:
                    logging.debug(f"CORS Pattern Matched, origin: {origin} "
                                  f" | pattern: {regex}")
                    self._set_cors_headers(origin, request)
                    return True
                else:
                    logging.debug(f"Partial Cors Match: {match.group()}")
        else:
            # Check to see if the origin contains an IP that matches a
            # current trusted connection
            match = re.search(r"^https?://([^/:]+)", origin)
            if match is not None:
                ip = match.group(1)
                try:
                    ipaddr = ipaddress.ip_address(ip)
                except ValueError:
                    pass
                else:
                    if self._check_authorized_ip(ipaddr):
                        logging.debug(f"Cors request matched trusted IP: {ip}")
                        self._set_cors_headers(origin, request)
                        return True
            logging.debug(f"No CORS match for origin: {origin}\n"
                          f"Patterns: {self.cors_domains}")
        return False

    def _set_cors_headers(self, origin, request):
        if request is None:
            return
        request.set_header("Access-Control-Allow-Origin", origin)
        request.set_header("Access-Control-Allow-Methods",
                           "GET, POST, PUT, DELETE, OPTIONS")
        request.set_header(
            "Access-Control-Allow-Headers",
            "Origin, Accept, Content-Type, X-Requested-With, "
            "X-CRSF-Token, Authorization, X-Access-Token, "
            "X-Api-Key")

    def close(self):
        self.prune_handler.stop()
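# --- Hedged usage sketch (not part of the class above) ---
# The authorization class relies on base64url_encode / base64url_decode
# helpers and a JWT_HEADER constant defined elsewhere in the module.  The
# snippet below is a minimal, self-contained reconstruction of the same
# hand-rolled HMAC-SHA256 JWT round trip; the header value is an assumption.
import base64
import hmac
import json
import secrets
import time

JWT_HEADER = {'alg': "HS256", 'typ': "JWT"}  # assumed header layout

def base64url_encode(data: bytes) -> bytes:
    return base64.urlsafe_b64encode(data).rstrip(b"=")

def base64url_decode(data: bytes) -> bytes:
    # restore the stripped padding before decoding
    return base64.urlsafe_b64decode(data + b"=" * (-len(data) % 4))

def generate_jwt(username: str, secret: bytes, exp_seconds: int = 3600) -> str:
    curtime = time.time()
    payload = {'iss': "Moonraker", 'iat': curtime,
               'exp': curtime + exp_seconds,
               'username': username, 'token_type': "auth"}
    enc_header = base64url_encode(json.dumps(JWT_HEADER).encode())
    enc_payload = base64url_encode(json.dumps(payload).encode())
    message = enc_header + b"." + enc_payload
    signature = base64url_encode(hmac.digest(secret, message, "sha256"))
    return (message + b"." + signature).decode()

def verify_jwt(jwt: str, secret: bytes) -> dict:
    header_b64, payload_b64, sig_b64 = jwt.encode().split(b".")
    calc_sig = hmac.digest(secret, header_b64 + b"." + payload_b64, "sha256")
    if not hmac.compare_digest(base64url_decode(sig_b64), calc_sig):
        raise ValueError("Invalid JWT signature")
    payload = json.loads(base64url_decode(payload_b64))
    if time.time() > payload['exp']:
        raise ValueError("JWT expired")
    return payload

if __name__ == "__main__":
    secret = secrets.token_bytes(32)
    token = generate_jwt("alice", secret)
    print(verify_jwt(token, secret)['username'])  # -> alice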
def test__lifecycle_hooks():
    application = Application()
    handler = HookTestHandler()
    application.add(handler)
    with ManagedServerLoop(application,
                           check_unused_sessions_milliseconds=30) as server:
        # wait for server callbacks to run before we mix in the
        # session, this keeps the test deterministic
        def check_done():
            if len(handler.hooks) == 4:
                server.io_loop.stop()

        server_load_checker = PeriodicCallback(check_done,
                                               1,
                                               io_loop=server.io_loop)
        server_load_checker.start()
        server.io_loop.start()
        server_load_checker.stop()

        # now we create a session
        client_session = pull_session(session_id='test__lifecycle_hooks',
                                      url=url(server),
                                      io_loop=server.io_loop)
        client_doc = client_session.document
        assert len(client_doc.roots) == 1

        server_session = server.get_session('/', client_session.id)
        server_doc = server_session.document
        assert len(server_doc.roots) == 1

        client_session.close()
        # expire the session quickly rather than after the
        # usual timeout
        server_session.request_expiration()

        def on_done():
            server.io_loop.stop()

        server.io_loop.call_later(0.1, on_done)

        server.io_loop.start()

    assert handler.hooks == [
        "server_loaded", "next_tick_server", "timeout_server",
        "periodic_server", "session_created", "next_tick_session", "modify",
        "timeout_session", "periodic_session", "session_destroyed",
        "server_unloaded"
    ]

    client_hook_list = client_doc.roots[0]
    server_hook_list = server_doc.roots[0]
    assert handler.load_count == 1
    assert handler.unload_count == 1
    assert handler.session_creation_async_value == 6
    assert client_doc.title == "Modified"
    assert server_doc.title == "Modified"
    # the client session doesn't see the event that adds "session_destroyed" since
    # we shut down at that point.
    assert client_hook_list.hooks == ["session_created", "modify"]
    assert server_hook_list.hooks == [
        "session_created", "modify", "session_destroyed"
    ]
Beispiel #36
0
class PeriodicCallback(param.Parameterized):
    """
    PeriodicCallback encapsulates a periodic callback which will run both
    in Tornado-based notebook environments and on the Bokeh server. By
    default the callback will run until the stop method is called,
    but count and timeout values can be set to limit the number of
    executions or the maximum length of time for which the callback
    will run. The callback may also be started and stopped by setting
    the running parameter to True or False respectively.
    """

    callback = param.Callable(doc="""
        The callback to execute periodically.""")

    count = param.Integer(default=None, doc="""
        Number of times the callback will be executed, by default
        this is unlimited.""")

    period = param.Integer(default=500, doc="""
        Period in milliseconds at which the callback is executed.""")

    timeout = param.Integer(default=None, doc="""
        Timeout in milliseconds from the start time at which the callback
        expires.""")

    running = param.Boolean(default=False, doc="""
        Toggles whether the periodic callback is currently running.""")

    def __init__(self, **params):
        super().__init__(**params)
        self._counter = 0
        self._start_time = None
        self._cb = None
        self._updating = False
        self._doc = None

    @param.depends('running', watch=True)
    def _start(self):
        if not self.running or self._updating:
            return
        self.start()

    @param.depends('running', watch=True)
    def _stop(self):
        if self.running or self._updating:
            return
        self.stop()

    @param.depends('period', watch=True)
    def _update_period(self):
        if self._cb:
            self.stop()
            self.start()

    def _periodic_callback(self):
        with edit_readonly(state):
            state.busy = True
        try:
            self.callback()
        finally:
            with edit_readonly(state):
                state.busy = False
        self._counter += 1
        if self.timeout is not None:
            dt = (time.time() - self._start_time) * 1000
            if dt > self.timeout:
                self.stop()
        if self._counter == self.count:
            self.stop()

    @property
    def counter(self):
        """
        Returns the execution count of the periodic callback.
        """
        return self._counter

    def _cleanup(self, session_context):
        self.stop()

    def start(self):
        """
        Starts running the periodic callback.
        """
        if self._cb is not None:
            raise RuntimeError('Periodic callback has already started.')
        if not self.running:
            try:
                self._updating = True
                self.running = True
            finally:
                self._updating = False
        self._start_time = time.time()
        if state.curdoc:
            self._doc = state.curdoc
            self._cb = self._doc.add_periodic_callback(self._periodic_callback, self.period)
        else:
            from tornado.ioloop import PeriodicCallback
            self._cb = PeriodicCallback(self._periodic_callback, self.period)
            self._cb.start()
        try:
            state.on_session_destroyed(self._cleanup)
        except Exception:
            pass

    def stop(self):
        """
        Stops running the periodic callback.
        """
        if self.running:
            try:
                self._updating = True
                self.running = False
            finally:
                self._updating = False
        self._counter = 0
        self._timeout = None
        if self._doc:
            self._doc.remove_periodic_callback(self._cb)
        elif self._cb:
            self._cb.stop()
        self._cb = None
        doc = self._doc or _curdoc()
        if doc:
            doc.session_destroyed_callbacks = {
                cb for cb in doc.session_destroyed_callbacks
                if cb is not self._cleanup
            }
            self._doc = None
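# --- Hedged usage sketch for the parameterized PeriodicCallback above ---
# Assumes Panel is installed; in Panel this class is normally created via
# pn.state.add_periodic_callback rather than instantiated directly.
import panel as pn

ticks = []

def tick():
    ticks.append(1)
    print(f"tick #{len(ticks)}")

# Run `tick` every 200 ms, at most 5 times; returns a PeriodicCallback
# whose `running` parameter can be toggled to pause and resume it.
cb = pn.state.add_periodic_callback(tick, period=200, count=5)
# cb.running = False   # pause
# cb.running = True    # resume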
Beispiel #37
0
class UpdateManager:
    def __init__(self, config):
        self.server = config.get_server()
        self.config = config
        self.config.read_supplemental_config(SUPPLEMENTAL_CFG_PATH)
        auto_refresh_enabled = config.getboolean('enable_auto_refresh', False)
        self.distro = config.get('distro', "debian").lower()
        if self.distro not in SUPPORTED_DISTROS:
            raise config.error(f"Unsupported distro: {self.distro}")
        self.cmd_helper = CommandHelper(config)
        env = sys.executable
        mooncfg = self.config[f"update_manager static {self.distro} moonraker"]
        self.updaters = {
            "system": PackageUpdater(self.cmd_helper),
            "moonraker": GitUpdater(mooncfg, self.cmd_helper, MOONRAKER_PATH,
                                    env)
        }
        # TODO: Check for client config in [update_manager].  This is
        # deprecated and will be removed.
        client_repo = config.get("client_repo", None)
        if client_repo is not None:
            client_path = config.get("client_path")
            name = client_repo.split("/")[-1]
            self.updaters[name] = WebUpdater(
                {
                    'repo': client_repo,
                    'path': client_path
                }, self.cmd_helper)
        client_sections = self.config.get_prefix_sections(
            "update_manager client")
        for section in client_sections:
            cfg = self.config[section]
            name = section.split()[-1]
            if name in self.updaters:
                raise config.error("Client repo named %s already added" %
                                   (name, ))
            client_type = cfg.get("type")
            if client_type == "git_repo":
                self.updaters[name] = GitUpdater(cfg, self.cmd_helper)
            elif client_type == "web":
                self.updaters[name] = WebUpdater(cfg, self.cmd_helper)
            else:
                raise config.error("Invalid type '%s' for section [%s]" %
                                   (client_type, section))

        self.cmd_request_lock = Lock()
        self.is_refreshing = False

        # Auto Status Refresh
        self.last_auto_update_time = 0
        self.refresh_cb = None
        if auto_refresh_enabled:
            self.refresh_cb = PeriodicCallback(self._handle_auto_refresh,
                                               UPDATE_REFRESH_INTERVAL_MS)
            self.refresh_cb.start()

        self.server.register_endpoint("/machine/update/moonraker", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/klipper", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/system", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/client", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/status", ["GET"],
                                      self._handle_status_request)
        self.server.register_notification("update_manager:update_response")
        self.server.register_notification("update_manager:update_refreshed")

        # Register Ready Event
        self.server.register_event_handler("server:klippy_identified",
                                           self._set_klipper_repo)
        # Initialize GitHub API Rate Limits and configured updaters
        IOLoop.current().spawn_callback(self._initialize_updaters,
                                        list(self.updaters.values()))

    async def _initialize_updaters(self, initial_updaters):
        self.is_refreshing = True
        await self.cmd_helper.init_api_rate_limit()
        for updater in initial_updaters:
            if isinstance(updater, PackageUpdater):
                ret = updater.refresh(False)
            else:
                ret = updater.refresh()
            if asyncio.iscoroutine(ret):
                await ret
        self.is_refreshing = False

    async def _set_klipper_repo(self):
        kinfo = self.server.get_klippy_info()
        if not kinfo:
            logging.info("No valid klippy info received")
            return
        kpath = kinfo['klipper_path']
        env = kinfo['python_path']
        kupdater = self.updaters.get('klipper', None)
        if kupdater is not None and kupdater.repo_path == kpath and \
                kupdater.env == env:
            # Current Klipper Updater is valid
            return
        kcfg = self.config[f"update_manager static {self.distro} klipper"]
        self.updaters['klipper'] = GitUpdater(kcfg, self.cmd_helper, kpath,
                                              env)
        await self.updaters['klipper'].refresh()

    async def _check_klippy_printing(self):
        klippy_apis = self.server.lookup_plugin('klippy_apis')
        result = await klippy_apis.query_objects({'print_stats': None},
                                                 default={})
        pstate = result.get('print_stats', {}).get('state', "")
        return pstate.lower() == "printing"

    async def _handle_auto_refresh(self):
        if await self._check_klippy_printing():
            # Don't Refresh during a print
            logging.info("Klippy is printing, auto refresh aborted")
            return
        cur_time = time.time()
        cur_hour = time.localtime(cur_time).tm_hour
        time_diff = cur_time - self.last_auto_update_time
        # Update packages if it has been more than 12 hours
        # and the local time is between 12AM and 5AM
        if time_diff < MIN_REFRESH_TIME or cur_hour >= MAX_PKG_UPDATE_HOUR:
            # Not within the update time window
            return
        self.last_auto_update_time = cur_time
        vinfo = {}
        need_refresh_all = not self.is_refreshing
        async with self.cmd_request_lock:
            self.is_refreshing = True
            try:
                for name, updater in list(self.updaters.items()):
                    if need_refresh_all:
                        ret = updater.refresh()
                        if asyncio.iscoroutine(ret):
                            await ret
                    if hasattr(updater, "get_update_status"):
                        vinfo[name] = updater.get_update_status()
            except Exception:
                logging.exception("Unable to Refresh Status")
                return
            finally:
                self.is_refreshing = False
        uinfo = self.cmd_helper.get_rate_limit_stats()
        uinfo['version_info'] = vinfo
        uinfo['busy'] = self.cmd_helper.is_update_busy()
        self.server.send_event("update_manager:update_refreshed", uinfo)

    async def _handle_update_request(self, web_request):
        if await self._check_klippy_printing():
            raise self.server.error("Update Refused: Klippy is printing")
        app = web_request.get_endpoint().split("/")[-1]
        if app == "client":
            app = web_request.get('name')
        inc_deps = web_request.get_boolean('include_deps', False)
        if self.cmd_helper.is_app_updating(app):
            return f"Object {app} is currently being updated"
        updater = self.updaters.get(app, None)
        if updater is None:
            raise self.server.error(f"Updater {app} not available")
        async with self.cmd_request_lock:
            self.cmd_helper.set_update_info(app, id(web_request))
            try:
                await updater.update(inc_deps)
            except Exception as e:
                self.cmd_helper.notify_update_response(f"Error updating {app}")
                self.cmd_helper.notify_update_response(str(e),
                                                       is_complete=True)
                raise
            finally:
                self.cmd_helper.clear_update_info()
        return "ok"

    async def _handle_status_request(self, web_request):
        check_refresh = web_request.get_boolean('refresh', False)
        # Don't refresh if a print is currently in progress or
        # if an update is in progress.  Just return the current
        # state
        if self.cmd_helper.is_update_busy() or \
                await self._check_klippy_printing():
            check_refresh = False
        need_refresh = False
        if check_refresh:
            # If there is an outstanding request processing a
            # refresh, we don't need to do it again.
            need_refresh = not self.is_refreshing
            await self.cmd_request_lock.acquire()
            self.is_refreshing = True
        vinfo = {}
        try:
            for name, updater in list(self.updaters.items()):
                await updater.check_initialized(120.)
                if need_refresh:
                    ret = updater.refresh()
                    if asyncio.iscoroutine(ret):
                        await ret
                if hasattr(updater, "get_update_status"):
                    vinfo[name] = updater.get_update_status()
        except Exception:
            raise
        finally:
            if check_refresh:
                self.is_refreshing = False
                self.cmd_request_lock.release()
        ret = self.cmd_helper.get_rate_limit_stats()
        ret['version_info'] = vinfo
        ret['busy'] = self.cmd_helper.is_update_busy()
        return ret

    def close(self):
        self.cmd_helper.close()
        if self.refresh_cb is not None:
            self.refresh_cb.stop()
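# --- Hedged example of exercising the endpoints registered above ---
# Host, port and the client name are assumptions; adjust to your install.
import requests

BASE = "http://localhost:7125"

# Query the current update status without forcing a refresh
status = requests.get(f"{BASE}/machine/update/status",
                      params={"refresh": "false"}).json()
print(status)

# Trigger an update for a client configured in an
# [update_manager client <name>] section (hypothetical name "mainsail")
requests.post(f"{BASE}/machine/update/client", params={"name": "mainsail"})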
Beispiel #38
0
class AzurePreemptibleWorkerPlugin(WorkerPlugin):
    """A worker plugin for azure spot instances

    This worker plugin will poll Azure's metadata service for preemption notifications.
    When a node is preempted, the plugin will attempt to gracefully shut down all workers
    on the node.

    This plugin can be used on any worker running on Azure spot instances, not just the
    ones created by ``dask-cloudprovider``.

    For more details on Azure spot instances see:
    https://docs.microsoft.com/en-us/azure/virtual-machines/linux/scheduled-events

    Parameters
    ----------
    poll_interval_s: int (optional)
        The interval, in seconds, at which the plugin polls the metadata service.

        Defaults to ``1``

    metadata_url: str (optional)
        The url of the metadata service to poll.

        Defaults to "http://169.254.169.254/metadata/scheduledevents?api-version=2019-08-01"

    termination_events: List[str] (optional)
        The type of events that will trigger the graceful shutdown

        Defaults to ``['Preempt', 'Terminate']``

    termination_offset_minutes: int (optional)
        Extra offset to apply to the preemption date. This may be negative, to start
        the graceful shutdown before the ``NotBefore`` date. It can also be positive, to
        start the shutdown after the ``NotBefore`` date, but this is at your own risk.

        Defaults to ``0``

    Examples
    --------

    Let's say you have a cluster and a client instance.
    For example using :class:`dask_kubernetes.KubeCluster`

    >>> from dask_kubernetes import KubeCluster
    >>> from distributed import Client
    >>> cluster = KubeCluster()
    >>> client = Client(cluster)

    You can add the worker plugin using the following:

    >>> from dask_cloudprovider.azure import AzurePreemptibleWorkerPlugin
    >>> client.register_worker_plugin(AzurePreemptibleWorkerPlugin())
    """
    def __init__(
        self,
        poll_interval_s=1,
        metadata_url=None,
        termination_events=None,
        termination_offset_minutes=0,
    ):
        self.callback = None
        self.loop = None
        self.worker = None
        self.poll_interval_s = poll_interval_s
        self.metadata_url = metadata_url or AZURE_EVENTS_METADATA_URL
        self.termination_events = termination_events or [
            "Preempt", "Terminate"
        ]
        self.termination_offset = datetime.timedelta(
            minutes=termination_offset_minutes)

        self.terminating = False
        self.not_before = None
        self._session = None
        self._lock = None

    async def _is_terminating(self):
        preempt_started = False
        async with self._session.get(self.metadata_url) as response:
            try:
                data = await response.json()
            # Sometimes Azure responds with a text/plain mime type
            except aiohttp.ContentTypeError:
                return
            # Sometimes the response doesn't contain the Events key
            events = data.get("Events", [])
            if events:
                logger.debug("Worker {}, got metadata events {}".format(
                    self.worker.name, events))
            for evt in events:
                event_type = evt["EventType"]
                if event_type not in self.termination_events:
                    continue

                event_status = evt.get("EventStatus")
                if event_status == "Started":
                    logger.info("Worker {}, node preemption started".format(
                        self.worker.name))
                    preempt_started = True
                    break

                not_before = evt.get("NotBefore")
                if not not_before:
                    continue

                not_before = datetime.datetime.strptime(
                    not_before, "%a, %d %b %Y %H:%M:%S GMT")
                if self.not_before is None:
                    logger.info(
                        "Worker {}, node deletion scheduled not before {}".
                        format(self.worker.name, not_before))
                    self.not_before = not_before
                    break
                if self.not_before < not_before:
                    logger.info(
                        "Worker {}, node deletion re-scheduled not before {}".
                        format(self.worker.name, not_before))
                    self.not_before = not_before
                    break

        return preempt_started or (self.not_before and
                                   (self.not_before + self.termination_offset <
                                    datetime.datetime.utcnow()))

    async def poll_status(self):
        if self.terminating:
            return
        if self._session is None:
            self._session = aiohttp.ClientSession(headers={"Metadata": "true"})
        if self._lock is None:
            self._lock = asyncio.Lock()

        async with self._lock:
            is_terminating = await self._is_terminating()
            if not is_terminating:
                return

            logger.info(
                "Worker {}, node is being deleted, attempting graceful shutdown"
                .format(self.worker.name))
            self.terminating = True
            await self._session.close()
            await self.worker.close_gracefully()

    def setup(self, worker):
        self.worker = worker
        self.loop = IOLoop.current()
        self.callback = PeriodicCallback(self.poll_status,
                                         callback_time=self.poll_interval_s *
                                         1_000)
        self.loop.add_callback(self.callback.start)
        logger.debug("Worker {}, registering preemptible plugin".format(
            self.worker.name))

    def teardown(self, worker):
        logger.debug("Worker {}, tearing down plugin".format(self.worker.name))
        if self.callback:
            self.callback.stop()
            self.callback = None
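# --- Hedged sketch of polling the scheduled-events endpoint directly ---
# The URL and the "Metadata: true" header come from the plugin above; the
# response shape (an "Events" list with EventType / EventStatus / NotBefore)
# is what the plugin expects from the metadata service.
import requests

AZURE_EVENTS_METADATA_URL = (
    "http://169.254.169.254/metadata/scheduledevents?api-version=2019-08-01")

resp = requests.get(AZURE_EVENTS_METADATA_URL,
                    headers={"Metadata": "true"}, timeout=2)
for evt in resp.json().get("Events", []):
    print(evt["EventType"], evt.get("EventStatus"), evt.get("NotBefore"))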
class Client(object):
    """Kodi Connect Websocket Connection"""
    def __init__(self, url, kodi, handler):
        self.url = url
        self.websocket = None
        self.connected = False
        self.should_stop = False

        self.kodi = kodi
        self.handler = handler

        self.periodic = PeriodicCallback(self.periodic_callback, 20000)

    def start(self):
        """Start IO loop and try to connect to the server"""
        self.connect()
        self.periodic.start()
        IOLoop.current().start()

    def stop(self):
        """Stop IO loop"""
        self.should_stop = True
        if self.websocket is not None:
            self.websocket.close()
        self.periodic.stop()
        IOLoop.current().stop()

    @gen.coroutine
    def connect(self):
        """Connect to the server and update connection to websocket"""
        email = __addon__.getSetting('email')
        secret = __addon__.getSetting('secret')
        if not email or not secret:
            logger.debug('Email and/or secret not defined, not connecting')
            return

        logger.debug('trying to connect')
        try:
            request = HTTPRequest(self.url,
                                  auth_username=email,
                                  auth_password=secret)

            self.websocket = yield websocket_connect(request)
        except Exception as ex:
            logger.debug('connection error: {}'.format(str(ex)))
            self.websocket = None
            notification(strings.FAILED_TO_CONNECT,
                         level='error',
                         tag='connection')
        else:
            logger.debug('Connected')
            self.connected = True
            notification(strings.CONNECTED, tag='connection')
            self.run()

    @gen.coroutine
    def run(self):
        """Main loop handling incomming messages"""
        while True:
            message_str = yield self.websocket.read_message()
            if message_str is None:
                logger.debug('Connection closed')
                self.websocket = None
                notification(strings.DISCONNECTED,
                             level='warn',
                             tag='connection')
                break

            try:
                message = json.loads(message_str)
                logger.debug(message)
                data = message['data']

                response_data = self.handler.handler(data)
            except Exception as ex:
                logger.error('Handler failed: {}'.format(str(ex)))
                response_data = {"status": "error", "error": "Unknown error"}

            self.websocket.write_message(
                json.dumps({
                    "correlationId": message['correlationId'],
                    "data": response_data
                }))

    def periodic_callback(self):
        """Periodic callback"""
        logger.debug('periodic_callback')
        if self.websocket is None:
            self.connect()
        else:
            self.websocket.write_message(json.dumps({"ping": "pong"}))

        try:
            self.kodi.update_cache()
        except Exception as ex:
            logger.error('Failed to update Kodi library: {}'.format(str(ex)))
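# --- Hedged illustration of the message framing used by Client.run ---
# Incoming frames carry a correlationId and a data payload; the reply echoes
# the correlationId so the server can match request and response.  The field
# values below are made up for illustration.
import json

incoming = json.dumps({
    "correlationId": "abc123",
    "data": {"type": "command", "commandName": "TurnOn"},
})

message = json.loads(incoming)
response_data = {"status": "ok"}           # produced by handler.handler(data)
reply = json.dumps({
    "correlationId": message["correlationId"],
    "data": response_data,
})
print(reply)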
Beispiel #40
0
class ProcStats:
    def __init__(self, config: ConfigHelper) -> None:
        self.server = config.get_server()
        self.ioloop = IOLoop.current()
        self.stat_update_cb = PeriodicCallback(
            self._handle_stat_update, STAT_UPDATE_TIME_MS)  # type: ignore
        self.vcgencmd: Optional[shell_command.ShellCommand] = None
        if os.path.exists(VC_GEN_CMD_FILE):
            logging.info("Detected 'vcgencmd', throttle checking enabled")
            shell_cmd: shell_command.ShellCommandFactory
            shell_cmd = self.server.load_component(config, "shell_command")
            self.vcgencmd = shell_cmd.build_shell_command(
                "vcgencmd get_throttled")
            self.server.register_notification("proc_stats:cpu_throttled")
        else:
            logging.info("Unable to find 'vcgencmd', throttle checking "
                         "disabled")
        self.temp_file = pathlib.Path(TEMPERATURE_PATH)
        self.smaps = pathlib.Path(STATM_FILE_PATH)
        self.server.register_endpoint(
            "/machine/proc_stats", ["GET"], self._handle_stat_request)
        self.server.register_event_handler(
            "server:klippy_shutdown", self._handle_shutdown)
        self.server.register_notification("proc_stats:proc_stat_update")
        self.proc_stat_queue: Deque[Dict[str, Any]] = deque(maxlen=30)
        self.last_update_time = time.time()
        self.last_proc_time = time.process_time()
        self.throttle_check_lock = Lock()
        self.total_throttled: int = 0
        self.last_throttled: int = 0
        self.update_sequence: int = 0
        self.stat_update_cb.start()

    async def _handle_stat_request(self,
                                   web_request: WebRequest
                                   ) -> Dict[str, Any]:
        ts: Optional[Dict[str, Any]] = None
        if self.vcgencmd is not None:
            ts = await self._check_throttled_state()
        return {
            'moonraker_stats': list(self.proc_stat_queue),
            'throttled_state': ts,
            'cpu_temp': self._get_cpu_temperature()
        }

    async def _handle_shutdown(self) -> None:
        msg = "\nMoonraker System Usage Statistics:"
        for stats in self.proc_stat_queue:
            msg += f"\n{self._format_stats(stats)}"
        msg += f"\nCPU Temperature: {self._get_cpu_temperature()}"
        logging.info(msg)
        if self.vcgencmd is not None:
            ts = await self._check_throttled_state()
            logging.info(f"Throttled Flags: {' '.join(ts['flags'])}")

    async def _handle_stat_update(self) -> None:
        update_time = time.time()
        proc_time = time.process_time()
        time_diff = update_time - self.last_update_time
        usage = round((proc_time - self.last_proc_time) / time_diff * 100, 2)
        mem, mem_units = self._get_memory_usage()
        cpu_temp = self._get_cpu_temperature()
        result = {
            "time": update_time,
            "cpu_usage": usage,
            "memory": mem,
            "mem_units": mem_units,
        }
        self.proc_stat_queue.append(result)
        self.server.send_event("proc_stats:proc_stat_update", {
            'moonraker_stats': result,
            'cpu_temp': cpu_temp
        })
        self.last_update_time = update_time
        self.last_proc_time = proc_time
        self.update_sequence += 1
        if self.update_sequence == THROTTLE_CHECK_INTERVAL:
            self.update_sequence = 0
            if self.vcgencmd is not None:
                ts = await self._check_throttled_state()
                cur_throttled = ts['bits']
                if cur_throttled & ~self.total_throttled:
                    self.server.add_log_rollover_item(
                        'throttled', f"CPU Throttled Flags: {ts['flags']}")
                if cur_throttled != self.last_throttled:
                    self.server.send_event("proc_stats:cpu_throttled", ts)
                self.last_throttled = cur_throttled
                self.total_throttled |= cur_throttled

    async def _check_throttled_state(self) -> Dict[str, Any]:
        async with self.throttle_check_lock:
            assert self.vcgencmd is not None
            try:
                resp = await self.vcgencmd.run_with_response(
                    timeout=.5, log_complete=False)
                ts = int(resp.strip().split("=")[-1], 16)
            except Exception:
                return {'bits': 0, 'flags': ["?"]}
            flags = []
            for flag, desc in THROTTLED_FLAGS.items():
                if flag & ts:
                    flags.append(desc)
            return {'bits': ts, 'flags': flags}

    def _get_memory_usage(self) -> Tuple[Optional[int], Optional[str]]:
        try:
            mem_data = self.smaps.read_text()
            rss_match = re.search(r"Rss:\s+(\d+)\s+(\w+)", mem_data)
            if rss_match is None:
                return None, None
            mem = int(rss_match.group(1))
            units = rss_match.group(2)
        except Exception:
            return None, None
        return mem, units

    def _get_cpu_temperature(self) -> Optional[float]:
        temp = None
        if self.temp_file.exists():
            try:
                res = int(self.temp_file.read_text().strip())
                temp = res / 1000.
            except Exception:
                return None
        return temp

    def _format_stats(self, stats: Dict[str, Any]) -> str:
        return f"System Time: {stats['time']:2f}, " \
               f"Usage: {stats['cpu_usage']}%, " \
               f"Memory: {stats['memory']} {stats['mem_units']}"

    def close(self) -> None:
        self.stat_update_cb.stop()
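# --- Hedged sketch of the vcgencmd throttle-bit parsing done above ---
# THROTTLED_FLAGS is not shown in this snippet; the mapping below follows the
# commonly documented Raspberry Pi bit layout and is illustrative only.
THROTTLED_FLAGS = {
    1 << 0: "Under-Voltage Detected",
    1 << 1: "Frequency Capped",
    1 << 2: "Currently Throttled",
    1 << 3: "Temperature Limit Active",
}

def parse_throttled(resp: str) -> dict:
    # vcgencmd prints a line such as "throttled=0x50005"
    bits = int(resp.strip().split("=")[-1], 16)
    flags = [desc for flag, desc in THROTTLED_FLAGS.items() if flag & bits]
    return {"bits": bits, "flags": flags}

print(parse_throttled("throttled=0x5"))
# -> {'bits': 5, 'flags': ['Under-Voltage Detected', 'Currently Throttled']}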
Beispiel #41
0
class JobScheduler:
    """a FIFO scheduler"""
    def __init__(self,
                 batch,
                 exit_on_finish,
                 interval,
                 executor_manager: ExecutorManager,
                 callbacks=None):
        self.batch = batch
        self.exit_on_finish = exit_on_finish
        self.executor_manager = executor_manager
        self.callbacks = callbacks if callbacks is not None else []

        self._io_loop_instance = None

        self._timer = PeriodicCallback(self.attempt_scheduling, interval)

        self._n_skipped = 0
        self._n_allocated = 0
        self._selected_jobs = []

    @property
    def n_skipped(self):
        return self._n_skipped

    @property
    def n_allocated(self):
        return self._n_allocated

    @property
    def interval(self):
        return self._timer.callback_time

    def start(self):
        self.executor_manager.prepare()
        self._timer.start()
        self.batch.start_time = time.time()

        # skip jobs that already have a final status and select runnable ones
        for job in self.batch.jobs:
            job_status = self.batch.get_job_status(job.name)
            if job_status != ShellJob.STATUS_INIT:
                logger.info(
                    f"job '{job.name}' status is '{job_status}', skip run.")
                self._n_skipped = self.n_skipped + 1
            else:
                self._selected_jobs.append(job)

        for callback in self.callbacks:
            callback.on_start(self.batch)

        # run in io loop
        self._io_loop_instance = tornado.ioloop.IOLoop.instance()
        logger.info('starting io loop')
        self._io_loop_instance.start()
        logger.info('exited io loop')

    def stop(self):
        if self._io_loop_instance is not None:
            self._io_loop_instance.add_callback(
                self._io_loop_instance.stop)  # let ioloop stop itself
            # self._io_loop_instance.stop()  # calling stop() directly does not work from another thread
            logger.info("added a stop callback to the ioloop")
        else:
            raise RuntimeError("Not started yet")

    def kill_job(self, job_name):
        # look up the job
        job: ShellJob = self.batch.get_job_by_name(job_name)
        if job is None:
            raise ValueError(f'job {job_name} does not exist')
        job_status = self.batch.get_job_status(job.name)

        logger.info(
            f"trying to kill job {job_name}, its status is {job_status}")

        # check job status

        if job_status != job.STATUS_RUNNING:
            raise RuntimeError(
                f"job {job_name} is not in {job.STATUS_RUNNING} status but is {job_status}"
            )

        # find executor and kill
        em = self.executor_manager
        executor = em.get_executor(job)
        logger.info(f"find executor {executor} of job {job_name}")
        if executor is not None:
            em.kill_executor(executor)
            logger.info(f"write failed status file for {job_name}")
            self._change_job_status(job, job.STATUS_FAILED)
        else:
            raise ValueError(f"no executor found for job {job.name}")

    def _change_job_status(self, job: ShellJob, next_status):
        self.change_job_status(self.batch, job, next_status)

    @staticmethod
    def change_job_status(batch: Batch, job: ShellJob, next_status):
        current_status = batch.get_job_status(job_name=job.name)
        target_status_file = batch.job_status_file_path(job_name=job.name,
                                                        status=next_status)

        def touch(f_path):
            with open(f_path, 'w'):
                pass

        if next_status == job.STATUS_INIT:
            raise ValueError(f"can not change to {next_status} ")
        elif next_status == job.STATUS_RUNNING:
            if current_status != job.STATUS_INIT:
                raise ValueError(
                    f"only job in {job.STATUS_INIT} can change to {next_status}"
                )

            # job.set_status(next_status)
            touch(target_status_file)

        elif next_status in job.FINAL_STATUS:
            if current_status != job.STATUS_RUNNING:
                raise ValueError(
                    f"only job in {job.STATUS_RUNNING} can change to "
                    f"{next_status} but now is {current_status}")
            # remove running status
            os.remove(
                batch.job_status_file_path(job_name=job.name,
                                           status=job.STATUS_RUNNING))

            # job.set_status(next_status)
            touch(target_status_file)
            reload_status = batch.get_job_status(job_name=job.name)
            assert reload_status == next_status, f"change job status failed, current status is {reload_status}," \
                                                 f" expected status is {next_status}"

        else:
            raise ValueError(f"unknown status {next_status}")

    def _release_executors(self, executor_manager):
        finished = []
        for executor in executor_manager.waiting_executors():
            executor: ShellExecutor = executor
            if executor.status() in ShellJob.FINAL_STATUS:
                finished.append(executor)

        for finished_executor in finished:
            executor_status = finished_executor.status()
            job = finished_executor.job
            logger.info(
                f"job {job.name} finished with status {executor_status}")
            self._change_job_status(job, finished_executor.status())
            job.end_time = time.time()  # update end time
            executor_manager.release_executor(finished_executor)
            self._handle_job_finished(job, finished_executor, job.elapsed)

    def _handle_callbacks(self, func):
        for callback in self.callbacks:
            try:
                callback: BatchCallback = callback
                func(callback)
            except Exception as e:
                logger.warning("handle callback failed", e)

    def _handle_job_start(self, job, executor):
        def f(callback):
            callback.on_job_start(self.batch, job, executor)

        self._handle_callbacks(f)

    def _handle_job_finished(self, job, executor, elapsed):
        def f(callback):
            callback.on_job_finish(self.batch, job, executor, elapsed)

        self._handle_callbacks(f)

    def _run_jobs(self, executor_manager):
        jobs = self._selected_jobs
        for job in jobs:
            if self.batch.get_job_status(job.name) != job.STATUS_INIT:
                # logger.info(f"job '{job.name}' status is {job.status}, skip run")
                continue
            # logger.debug(f'trying to alloc resource for job {job.name}')
            try:
                executor = executor_manager.alloc_executor(job)
            except NoResourceException:
                # logger.debug(f"no enough resource for job {job.name} , wait for resource to continue ...")
                break
            except Exception:
                # skip the job, and do not clean the executor
                self._change_job_status(job,
                                        job.STATUS_FAILED)  # TODO on job break
                logger.exception(
                    f"failed to alloc resource for job '{job.name}'")
                continue

            self._n_allocated = self.n_allocated + 1
            job.start_time = time.time()  # update start time
            self._handle_job_start(job, executor)
            process_msg = f"{len(executor_manager.allocated_executors())}/{len(jobs)}"
            logger.info(
                f'allocated resource for job {job.name}({process_msg}), data dir at {job.data_dir_path}'
            )
            self._change_job_status(job, job.STATUS_RUNNING)

            try:
                executor.run()
            except Exception:
                logger.exception(f"failed to run job '{job.name}'")
                self._change_job_status(job,
                                        job.STATUS_FAILED)  # TODO on job break
                executor_manager.release_executor(executor)
                continue

    def _handle_on_finished(self):
        for callback in self.callbacks:
            callback: BatchCallback = callback
            callback.on_finish(self.batch, self.batch.elapsed)

    def attempt_scheduling(self):
        # check all jobs finished
        job_finished = self.batch.is_finished()
        if job_finished:
            self.batch.end_time = time.time()
            batch_summary = json.dumps(self.batch.summary())
            logger.info("all jobs finished, stop scheduler:\n" + batch_summary)
            self._timer.stop()  # stop the timer
            if self.exit_on_finish:
                self.stop()
            self._handle_on_finished()
            return

        self._release_executors(self.executor_manager)
        self._run_jobs(self.executor_manager)
def test__lifecycle_hooks(ManagedServerLoop) -> None:
    application = Application()
    handler = HookTestHandler()
    application.add(handler)
    with ManagedServerLoop(application, check_unused_sessions_milliseconds=30) as server:
        # wait for server callbacks to run before we mix in the
        # session, this keeps the test deterministic
        def check_done():
            if len(handler.hooks) == 4:
                server.io_loop.stop()
        server_load_checker = PeriodicCallback(check_done, 1)
        server_load_checker.start()
        server.io_loop.start()
        server_load_checker.stop()

        # now we create a session
        client_session = pull_session(session_id='test__lifecycle_hooks',
                                      url=url(server),
                                      io_loop=server.io_loop)
        client_doc = client_session.document
        assert len(client_doc.roots) == 1

        server_session = server.get_session('/', client_session.id)
        server_doc = server_session.document
        assert len(server_doc.roots) == 1

        # we have to capture these here for examination later, since after
        # the session is closed, doc.roots will be emptied
        client_hook_list = client_doc.roots[0]
        server_hook_list = server_doc.roots[0]

        client_session.close()
        # expire the session quickly rather than after the
        # usual timeout
        server_session.request_expiration()

        def on_done():
            server.io_loop.stop()

        server.io_loop.call_later(0.1, on_done)

        server.io_loop.start()

    assert handler.hooks == ["server_loaded",
                             "next_tick_server",
                             "timeout_server",
                             "periodic_server",
                             "session_created",
                             "modify",
                             "next_tick_session",
                             "timeout_session",
                             "periodic_session",
                             "session_destroyed",
                             "server_unloaded"]

    assert handler.load_count == 1
    assert handler.unload_count == 1
    # this is 3 instead of 6 because locked callbacks on destroyed sessions
    # are turned into no-ops
    assert handler.session_creation_async_value == 3
    assert client_doc.title == "Modified"
    assert server_doc.title == "Modified"
    # only the handler sees the event that adds "session_destroyed" since
    # the session is shut down at that point.
    assert client_hook_list.hooks == ["session_created", "modify"]
    assert server_hook_list.hooks == ["session_created", "modify"]
Beispiel #43
0
class JsonStatusServer(AsyncDeviceServer):
    VERSION_INFO = ("reynard-eff-jsonstatusserver-api", 0, 1)
    BUILD_INFO = ("reynard-eff-jsonstatusserver-implementation", 0, 1, "rc1")

    def __init__(self,
                 server_host,
                 server_port,
                 mcast_group=JSON_STATUS_MCAST_GROUP,
                 mcast_port=JSON_STATUS_PORT,
                 parser=EFF_JSON_CONFIG,
                 dummy=False):
        self._mcast_group = mcast_group
        self._mcast_port = mcast_port
        self._parser = parser
        self._dummy = dummy
        if not dummy:
            self._catcher_thread = StatusCatcherThread()
        else:
            self._catcher_thread = None
        self._monitor = None
        self._updaters = {}
        self._controlled = set()
        super(JsonStatusServer, self).__init__(server_host, server_port)

    @coroutine
    def _update_sensors(self):
        log.debug("Updating sensor values")
        data = self._catcher_thread.data
        if data is None:
            log.warning("Catcher thread has not received any data yet")
            return
        for name, params in self._parser.items():
            if name in self._controlled:
                continue
            if "updater" in params:
                self._sensors[name].set_value(params["updater"](data))

    def start(self):
        """start the server"""
        super(JsonStatusServer, self).start()
        if not self._dummy:
            self._catcher_thread.start()
            self._monitor = PeriodicCallback(self._update_sensors,
                                             1000,
                                             io_loop=self.ioloop)
            self._monitor.start()

    def stop(self):
        """stop the server"""
        if not self._dummy:
            if self._monitor:
                self._monitor.stop()
            self._catcher_thread.stop()
        return super(JsonStatusServer, self).stop()

    @request()
    @return_reply(Str())
    def request_xml(self, req):
        """request an XML version of the status message"""
        def make_elem(parent, name, text):
            child = etree.Element(name)
            child.text = text
            parent.append(child)

        @coroutine
        def convert():
            try:
                root = etree.Element("TelescopeStatus",
                                     attrib={"timestamp": str(time.time())})
                for name, sensor in self._sensors.items():
                    child = etree.Element("TelStat")
                    make_elem(child, "Name", name)
                    make_elem(child, "Value", str(sensor.value()))
                    make_elem(child, "Status", str(sensor.status()))
                    make_elem(child, "Type", self._parser[name]["type"])
                    if "units" in self._parser[name]:
                        make_elem(child, "Units", self._parser[name]["units"])
                    root.append(child)
            except Exception as error:
                req.reply("fail", str(error))
            else:
                req.reply("ok", etree.tostring(root))

        self.ioloop.add_callback(convert)
        raise AsyncReply

    @request()
    @return_reply(Str())
    def request_json(self, req):
        """request an JSON version of the status message"""
        return ("ok", self.as_json())

    def as_json(self):
        """Convert status sensors to JSON object"""
        out = {}
        for name, sensor in self._sensors.items():
            out[name] = str(sensor.value())
        return json.dumps(out)

    @request(Str())
    @return_reply(Str())
    def request_sensor_control(self, req, name):
        """take control of a given sensor value"""
        if name not in self._sensors:
            return ("fail", "No sensor named '{0}'".format(name))
        else:
            self._controlled.add(name)
            return ("ok", "{0} under user control".format(name))

    @request()
    @return_reply(Str())
    def request_sensor_control_all(self, req):
        """take control of all sensors value"""
        for name, sensor in self._sensors.items():
            self._controlled.add(name)
        return ("ok",
                "{0} sensors under user control".format(len(self._controlled)))

    @request()
    @return_reply(Int())
    def request_sensor_list_controlled(self, req):
        """List all controlled sensors"""
        count = len(self._controlled)
        for name in list(self._controlled):
            req.inform("{0} -- {1}".format(name, self._sensors[name].value()))
        return ("ok", count)

    @request(Str())
    @return_reply(Str())
    def request_sensor_release(self, req, name):
        """release a sensor from user control"""
        if name not in self._sensors:
            return ("fail", "No sensor named '{0}'".format(name))
        else:
            self._controlled.remove(name)
            return ("ok", "{0} released from user control".format(name))

    @request()
    @return_reply(Str())
    def request_sensor_release_all(self, req):
        """take control of all sensors value"""
        self._controlled = set()
        return ("ok", "All sensors released")

    @request(Str(), Str())
    @return_reply(Str())
    def request_sensor_set(self, req, name, value):
        """Set the value of a sensor"""
        if name not in self._sensors:
            return ("fail", "No sensor named '{0}'".format(name))
        if name not in self._controlled:
            return ("fail", "Sensor '{0}' not under user control".format(name))
        try:
            param = self._parser[name]
            value = TYPE_CONVERTER[param["type"]](value)
            self._sensors[name].set_value(value)
        except Exception as error:
            return ("fail", str(error))
        else:
            return ("ok", "{0} set to {1}".format(name,
                                                  self._sensors[name].value()))

    def setup_sensors(self):
        """Set up basic monitoring sensors.
        """
        for name, params in self._parser.items():
            if params["type"] == "float":
                sensor = Sensor.float(name,
                                      description=params["description"],
                                      unit=params.get("units", None),
                                      default=params.get("default", 0.0),
                                      initial_status=Sensor.UNKNOWN)
            elif params["type"] == "string":
                sensor = Sensor.string(name,
                                       description=params["description"],
                                       default=params.get("default", ""),
                                       initial_status=Sensor.UNKNOWN)
            elif params["type"] == "int":
                sensor = Sensor.integer(name,
                                        description=params["description"],
                                        default=params.get("default", 0),
                                        unit=params.get("units", None),
                                        initial_status=Sensor.UNKNOWN)
            elif params["type"] == "bool":
                sensor = Sensor.boolean(name,
                                        description=params["description"],
                                        default=params.get("default", False),
                                        initial_status=Sensor.UNKNOWN)
            else:
                raise Exception("Unknown sensor type '{0}' requested".format(
                    params["type"]))
            self.add_sensor(sensor)
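
# A minimal usage sketch for JsonStatusServer with hypothetical host/port
# values, assuming the katcp AsyncDeviceServer base class wires up self.ioloop
# and calls setup_sensors(); dummy=True skips the multicast catcher thread.
from tornado.ioloop import IOLoop

status_server = JsonStatusServer("0.0.0.0", 5000, dummy=True)
status_server.start()
try:
    IOLoop.current().start()
except KeyboardInterrupt:
    status_server.stop()
    IOLoop.current().stop()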
Example #44
0
class WorkStealing(SchedulerPlugin):
    def __init__(self, scheduler):
        self.scheduler = scheduler
        self.stealable_all = [set() for i in range(15)]
        self.stealable = dict()
        self.key_stealable = dict()
        self.stealable_unknown_durations = defaultdict(set)

        self.cost_multipliers = [1 + 2**(i - 6) for i in range(15)]
        self.cost_multipliers[0] = 1

        for worker in scheduler.workers:
            self.add_worker(worker=worker)

        self._pc = PeriodicCallback(callback=self.balance,
                                    callback_time=100,
                                    io_loop=self.scheduler.loop)
        self.scheduler.loop.add_callback(self._pc.start)
        self.scheduler.plugins.append(self)
        self.scheduler.extensions['stealing'] = self
        self.scheduler.events['stealing'] = deque(maxlen=100000)
        self.count = 0

    @property
    def log(self):
        return self.scheduler.events['stealing']

    def add_worker(self, scheduler=None, worker=None):
        self.stealable[worker] = [set() for i in range(15)]

    def remove_worker(self, scheduler=None, worker=None):
        del self.stealable[worker]

    def teardown(self):
        self._pc.stop()

    def transition(self,
                   key,
                   start,
                   finish,
                   compute_start=None,
                   compute_stop=None,
                   *args,
                   **kwargs):
        if finish == 'processing':
            self.put_key_in_stealable(key)

        if start == 'processing':
            self.remove_key_from_stealable(key)
            if finish == 'memory':
                ks = key_split(key)
                if ks in self.stealable_unknown_durations:
                    for k in self.stealable_unknown_durations.pop(ks):
                        if self.scheduler.task_state[k] == 'processing':
                            self.put_key_in_stealable(k, split=ks)

    def put_key_in_stealable(self, key, split=None):
        worker = self.scheduler.rprocessing[key]
        cost_multiplier, level = self.steal_time_ratio(key, split=split)
        if cost_multiplier is not None:
            self.stealable_all[level].add(key)
            self.stealable[worker][level].add(key)
            self.key_stealable[key] = (worker, level)

    def remove_key_from_stealable(self, key):
        result = self.key_stealable.pop(key, None)
        if result is not None:
            worker, level = result
            try:
                self.stealable[worker][level].remove(key)
            except KeyError:
                pass
            try:
                self.stealable_all[level].remove(key)
            except KeyError:
                pass

    def steal_time_ratio(self, key, split=None):
        """ The compute to communication time ratio of a key

        Returns
        -------

        cost_multiplier: The increased cost from moving this task, as a factor.
        For example, a result of zero implies a task without dependencies.
        level: The location within a stealable list to place this value.
        """
        if (key not in self.scheduler.loose_restrictions and
            (key in self.scheduler.host_restrictions
             or key in self.scheduler.worker_restrictions)
                or key in self.scheduler.resource_restrictions):
            return None, None  # don't steal

        if not self.scheduler.dependencies[key]:  # no dependencies fast path
            return 0, 0

        nbytes = sum(
            self.scheduler.nbytes.get(k, 1000)
            for k in self.scheduler.dependencies[key])

        transfer_time = nbytes / BANDWIDTH + LATENCY
        split = split or key_split(key)
        if split in fast_tasks:
            return None, None
        try:
            worker = self.scheduler.rprocessing[key]
            compute_time = self.scheduler.processing[worker][key]
        except KeyError:
            self.stealable_unknown_durations[split].add(key)
            return None, None
        else:
            if compute_time < 0.005:  # 5ms, just give up
                return None, None
            cost_multiplier = transfer_time / compute_time
            if cost_multiplier > 100:
                return None, None

            level = int(round(log(cost_multiplier) / log_2 + 6, 0))
            level = max(1, level)
            return cost_multiplier, level

    def move_task(self, key, victim, thief):
        try:
            if self.scheduler.validate:
                if victim != self.scheduler.rprocessing[key]:
                    import pdb
                    pdb.set_trace()

            self.remove_key_from_stealable(key)
            logger.debug("Moved %s, %s: %2f -> %s: %2f", key, victim,
                         self.scheduler.occupancy[victim], thief,
                         self.scheduler.occupancy[thief])

            duration = self.scheduler.processing[victim].pop(key)
            self.scheduler.occupancy[victim] -= duration
            self.scheduler.total_occupancy -= duration

            duration = self.scheduler.task_duration.get(key_split(key), 0.5)
            duration += sum(self.scheduler.nbytes[key]
                            for key in self.scheduler.dependencies[key] -
                            self.scheduler.has_what[thief]) / BANDWIDTH
            self.scheduler.processing[thief][key] = duration
            self.scheduler.rprocessing[key] = thief
            self.scheduler.occupancy[thief] += duration
            self.scheduler.total_occupancy += duration
            self.put_key_in_stealable(key)

            self.scheduler.worker_comms[victim].send({
                'op': 'release-task',
                'reason': 'stolen',
                'key': key
            })

            try:
                self.scheduler.send_task_to_worker(thief, key)
            except CommClosedError:
                self.scheduler.remove_worker(thief)
        except CommClosedError:
            logger.info("Worker comm closed while stealing: %s", victim)
        except Exception as e:
            logger.exception(e)
            if LOG_PDB:
                import pdb
                pdb.set_trace()
            raise

    def balance(self):
        with log_errors():
            i = 0
            s = self.scheduler
            occupancy = s.occupancy
            idle = s.idle
            saturated = s.saturated
            if not idle or len(idle) == len(self.scheduler.workers):
                return

            log = list()
            start = time()

            seen = False
            acted = False

            if not s.saturated:
                saturated = topk(10, s.workers, key=occupancy.get)
                saturated = [
                    w for w in saturated if occupancy[w] > 0.2
                    and len(s.processing[w]) > s.ncores[w]
                ]
            elif len(s.saturated) < 20:
                saturated = sorted(saturated, key=occupancy.get, reverse=True)

            if len(idle) < 20:
                idle = sorted(idle, key=occupancy.get)

            for level, cost_multiplier in enumerate(self.cost_multipliers):
                if not idle:
                    break
                for sat in list(saturated):
                    stealable = self.stealable[sat][level]
                    if not stealable or not idle:
                        continue
                    else:
                        seen = True

                    for key in list(stealable):
                        i += 1
                        if not idle:
                            break
                        idl = idle[i % len(idle)]
                        duration = s.processing[sat][key]

                        if (occupancy[idl] + cost_multiplier * duration <=
                                occupancy[sat] - duration / 2):
                            self.move_task(key, sat, idl)
                            log.append((start, level, key, duration, sat,
                                        occupancy[sat], idl, occupancy[idl]))
                            self.scheduler.check_idle_saturated(sat)
                            self.scheduler.check_idle_saturated(idl)
                            seen = True

                if self.cost_multipliers[level] < 20:  # don't steal from public at cost
                    stealable = self.stealable_all[level]
                    if stealable:
                        seen = True
                    for key in list(stealable):
                        if not idle:
                            break

                        sat = s.rprocessing[key]
                        if occupancy[sat] < 0.2:
                            continue
                        if len(s.processing[sat]) <= s.ncores[sat]:
                            continue

                        i += 1
                        idl = idle[i % len(idle)]
                        duration = s.processing[sat][key]

                        if (occupancy[idl] + cost_multiplier * duration <=
                                occupancy[sat] - duration / 2):
                            self.move_task(key, sat, idl)
                            log.append((start, level, key, duration, sat,
                                        occupancy[sat], idl, occupancy[idl]))
                            self.scheduler.check_idle_saturated(sat)
                            self.scheduler.check_idle_saturated(idl)
                            seen = True

                if seen and not acted:
                    break

            if log:
                self.log.append(log)
                self.count += 1
            stop = time()
            if self.scheduler.digests:
                self.scheduler.digests['steal-duration'].add(stop - start)

    def restart(self, scheduler):
        for stealable in self.stealable.values():
            for s in stealable:
                s.clear()

        for s in self.stealable_all:
            s.clear()
        self.key_stealable.clear()
        self.stealable_unknown_durations.clear()

    def story(self, *keys):
        keys = set(keys)
        return [t for L in self.log for t in L if any(x in keys for x in t)]
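
# A small standalone illustration of the level computation in steal_time_ratio
# above: level is roughly log2(cost_multiplier) + 6, clamped to at least 1,
# which approximately inverts cost_multipliers[level] = 1 + 2**(level - 6).
from math import log

log_2 = log(2)

def level_for(cost_multiplier):
    return max(1, int(round(log(cost_multiplier) / log_2 + 6, 0)))

assert level_for(1.0) == 6   # moving costs as much as computing
assert level_for(0.25) == 4  # cheap to move: two levels lower
assert level_for(8.0) == 9   # expensive to move: three levels higher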
Example #45
0
class CrawlProcessHandlerBase(object):
    def __init__(self, game_params, username, logger):
        self.game_params = game_params
        self.username = username
        self.logger = logging.LoggerAdapter(logger, {})
        try:
            self.logger.manager
            self.logger.process = self._process_log_msg
        except AttributeError:
            # This is a workaround for a python 3.5 bug with chained
            # LoggerAdapters, where delegation is not handled properly (e.g.
            # manager isn't set, _log isn't available, etc.). This simple fix
            # only handles two levels of chaining. The general fix is to
            # upgrade to python 3.7.
            # Issue: https://bugs.python.org/issue31457
            self.logger = logging.LoggerAdapter(logger.logger, {})
            self.logger.process = lambda m,k: logger.process(*self._process_log_msg(m, k))

        self.queue_messages = False

        self.process = None
        self.client_path = self.config_path("client_path")
        self.crawl_version = None
        self.where = {}
        self.wheretime = 0
        self.last_milestone = None
        self.kill_timeout = None

        self.muted = set()

        now = datetime.datetime.utcnow()
        self.formatted_time = now.strftime("%Y-%m-%d.%H:%M:%S")
        self.lock_basename = self.formatted_time + ".ttyrec"

        self.end_callback = None
        self._receivers = set()
        self.last_activity_time = time.time()
        self.idle_checker = PeriodicCallback(self.check_idle, 10000)
        self.idle_checker.start()
        self._was_idle = False
        self.last_watcher_join = 0
        self.receiving_direct_milestones = False

        global last_game_id
        self.id = last_game_id + 1
        last_game_id = self.id

    def _process_log_msg(self, msg, kwargs):
        return "P%-5s %s" % (self.id, msg), kwargs

    def format_path(self, path):
        return dgl_format_str(path, self.username, self.game_params)

    def config_path(self, key):
        if key not in self.game_params:
            return None
        base_path = self.format_path(self.game_params[key])
        if key == "socket_path" and config.get('live_debug'):
            # TODO: this is kind of brute-force given that regular paths aren't
            # validated at all...
            debug_path = os.path.join(base_path, 'live-debug')
            if not os.path.isdir(debug_path):
                os.makedirs(debug_path)
            return debug_path
        else:
            return base_path

    def idle_time(self):
        return int(time.time() - self.last_activity_time)

    def is_idle(self):
        return self.idle_time() > 30

    def check_idle(self):
        if self.is_idle() != self._was_idle:
            self._was_idle = self.is_idle()
            if config.get('dgl_mode'):
                update_all_lobbys(self)

    def flush_messages_to_all(self):
        for receiver in self._receivers:
            receiver.flush_messages()

    def write_to_all(self, msg, send): # type: (str, bool) -> None
        for receiver in self._receivers:
            receiver.append_message(msg, send)

    def send_to_all(self, msg, **data): # type: (str, Any) -> None
        for receiver in self._receivers:
            receiver.send_message(msg, **data)

    def chat_help_message(self, source, command, desc):
        if len(command) == 0:
            self.handle_notification_raw(source,
                        "&nbsp;" * 8 + "<span>%s</span>" % (xhtml_escape(desc)))
        else:
            self.handle_notification_raw(source,
                        "&nbsp;" * 4 + "<span>%s: %s</span>" %
                                (xhtml_escape(command), xhtml_escape(desc)))

    def chat_command_help(self, source):
        # TODO: generalize
        # the chat window is basically fixed width, and these are calibrated
        # to not do a linewrap
        self.handle_notification(source, "The following chat commands are available:")
        self.chat_help_message(source, "/help", "show chat command help.")
        self.chat_help_message(source, "/hide", "hide the chat window.")
        if self.is_player(source):
            self.chat_help_message(source, "/mute <name>", "add <name> to the mute list.")
            self.chat_help_message(source, "", "Must be present in channel.")
            self.chat_help_message(source, "/mutelist", "show your entire mute list.")
            self.chat_help_message(source, "/unmute <name>", "remove <name> from the mute list.")
            self.chat_help_message(source, "/unmute *", "clear your mute list.")

    def handle_chat_command(self, source_ws, text):
        # type: (CrawlWebSocket, str) -> bool
        source = source_ws.username
        text = text.strip()
        if len(text) == 0 or text[0] != '/':
            return False
        splitlist = text.split(None, 1)
        if len(splitlist) == 1:
            command = splitlist[0]
            remainder = ""
        else:
            command, remainder = splitlist
        command = command.lower()
        # TODO: generalize
        if command == "/mute":
            self.mute(source, remainder)
        elif command == "/unmute":
            self.unmute(source, remainder)
        elif command == "/mutelist":
            self.show_mute_list(source)
        elif command == "/help":
            self.chat_command_help(source)
        elif command == "/hide":
            self.hide_chat(source_ws, remainder.strip())
        else:
            return False
        return True

    def handle_chat_message(self, username, text): # type: (str, str) -> None
        if username in self.muted: # TODO: message?
            return
        chat_msg = ("<span class='chat_sender'>%s</span>: <span class='chat_msg'>%s</span>" %
                    (username, xhtml_escape(text)))
        self.send_to_all("chat", content = chat_msg)

    def get_receivers_by_username(self, username):
        result = list()
        for w in self._receivers:
            if not w.username:
                continue
            if w.username == username:
                result.append(w)
        return result

    def get_primary_receiver(self):
        # TODO: does this work with console? Probably not...
        if self.username is None:
            return None
        receivers = self.get_receivers_by_username(self.username)
        for r in receivers:
            if not r.watched_game:
                return r
        return None

    def send_to_user(self, username, msg, **data):
        # type: (str, str, Any) -> None
        # a single user may be viewing from multiple receivers
        for receiver in self.get_receivers_by_username(username):
            receiver.send_message(msg, **data)

    # obviously, don't use this for player/spectator-accessible data. But, it
    # is still partially sanitized in chat.js.
    def handle_notification_raw(self, username, text):
        # type: (str, str) -> None
        msg = ("<span class='chat_msg'>%s</span>" % text)
        self.send_to_user(username, "chat", content=msg, meta=True)

    def handle_notification(self, username, text):
        # type: (str, str) -> None
        self.handle_notification_raw(username, xhtml_escape(text))

    def handle_process_end(self):
        if self.kill_timeout:
            IOLoop.current().remove_timeout(self.kill_timeout)
            self.kill_timeout = None

        self.idle_checker.stop()

        # send game_ended message to watchers. The player is handled in cleanup
        # code in ws_handler.py.
        for watcher in list(self._receivers):
            if watcher.watched_game == self:
                watcher.send_message("game_ended", reason = self.exit_reason,
                                     message = self.exit_message,
                                     dump = self.exit_dump_url)
                watcher.go_lobby()

        if self.end_callback:
            self.end_callback()

    def get_watchers(self, chatting_only=False):
        # TODO: I don't understand why this code didn't just use self.username,
        # when will this be different than player_name? Maybe for a console
        # player?
        player_name = None
        watchers = list()
        for w in self._receivers:
            if not w.username: # anon
                continue
            if chatting_only and w.chat_hidden:
                continue
            if not w.watched_game:
                player_name = w.username
            else:
                watchers.append(w.username)
        watchers.sort(key=lambda s:s.lower())
        return (player_name, watchers)

    def is_player(self, username):
        # TODO: probably doesn't work for console players spectating themselves
        # TODO: let admin accounts mute as well?
        player_name, watchers = self.get_watchers()
        return (username == player_name)

    def hide_chat(self, receiver, param):
        if param == "forever":
            receiver.send_message("super_hide_chat")
            receiver.chat_hidden = True # currently only for super hidden chat
            self.update_watcher_description()
        else:
            receiver.send_message("toggle_chat")

    def restore_mutelist(self, source, l):
        if not self.is_player(source) or l is None:
            return
        if len(l) == 0:
            return
        self.muted = {u for u in l if u != source}
        self.handle_notification(source, "Restoring mute list.")
        self.show_mute_list(source)
        self.logger.info("Player '%s' restoring mutelist %s" %
                                            (source, repr(list(self.muted))))

    def save_mutelist(self, source):
        if not self.is_player(source):
            return
        receiver = self.get_primary_receiver()
        if receiver is not None:
            receiver.save_mutelist(list(self.muted))

    def mute(self, source, target):
        if not self.is_player(source):
            self.handle_notification(source,
                            "You do not have permission to mute spectators.")
            return False
        if (source == target):
            self.handle_notification(source, "You can't mute yourself!")
            return False
        player_name, watchers = self.get_watchers()
        watchers = set(watchers)
        if not target in watchers:
            self.handle_notification(source, "Mute who??")
            return False
        self.logger.info("Player '%s' has muted '%s'" % (source, target))
        self.handle_notification(source,
                            "Spectator '%s' has now been muted." % target)
        self.muted |= {target}
        self.save_mutelist(source)
        self.update_watcher_description()
        return True

    def unmute(self, source, target):
        if not self.is_player(source):
            self.handle_notification(source,
                            "You do not have permission to unmute spectators.")
            return False
        if (source == target):
            self.handle_notification(source,
                                    "You can't unmute (or mute) yourself!")
            return False
        if target == "*":
            if (len(self.muted) == 0):
                self.handle_notification(source, "No one is muted!")
                return False
            self.logger.info("Player '%s' has cleared their mute list." % (source))
            self.handle_notification(source, "You have cleared your mute list.")
            self.muted = set()
            self.save_mutelist(source)
            self.update_watcher_description()
            return True

        if not target in self.muted:
            self.handle_notification(source, "Unmute who??")
            return False

        self.logger.info("Player '%s' has unmuted '%s'" % (source, target))
        self.handle_notification(source, "You have unmuted '%s'." % target)
        self.muted -= {target}
        self.save_mutelist(source)
        self.update_watcher_description()
        return True

    def show_mute_list(self, source):
        if not self.is_player(source):
            return False
        names = list(self.muted)
        names.sort(key=lambda s: s.lower())
        if len(names) == 0:
            self.handle_notification(source, "No one is muted.")
        else:
            self.handle_notification(source, "You have muted: " +
                                                            ", ".join(names))
        return True

    def get_anon(self):
        return [w for w in self._receivers if not w.username]

    def update_watcher_description(self):
        player_url_template = config.get('player_url')

        def wrap_name(watcher, is_player=False):
            if is_player:
                class_type = 'player'
            else:
                class_type = 'watcher'
            n = watcher
            if (watcher in self.muted):
                n += " (muted)"
            if player_url_template is None:
                return "<span class='{0}'>{1}</span>".format(class_type, n)
            player_url = player_url_template.replace("%s", watcher.lower())
            username = "******".format(player_url, class_type, n)
            return username

        player_name, watchers = self.get_watchers(True)

        watcher_names = []
        if player_name is not None:
            watcher_names.append(wrap_name(player_name, True))
        watcher_names += [wrap_name(w) for w in watchers]

        anon_count = len(self.get_anon())
        s = ", ".join(watcher_names)
        if len(watcher_names) > 0 and anon_count > 0:
            s = s + " and %i Anon" % anon_count
        elif anon_count > 0:
            s = "%i Anon" % anon_count
        self.send_to_all("update_spectators",
                         count = self.watcher_count(),
                         names = s)

        if config.get('dgl_mode'):
            update_all_lobbys(self)

    def add_watcher(self, watcher):
        self.last_watcher_join = time.time()
        if self.client_path:
            self._send_client(watcher)
            if watcher.watched_game == self:
                watcher.send_json_options(self.game_params["id"], self.username)
        self._receivers.add(watcher)
        self.update_watcher_description()

    def remove_watcher(self, watcher):
        self._receivers.remove(watcher)
        self.update_watcher_description()

    def watcher_count(self):
        return len([w for w in self._receivers if w.watched_game and not w.chat_hidden])

    def send_client_to_all(self):
        for receiver in self._receivers:
            self._send_client(receiver)
            if receiver.watched_game == self:
                receiver.send_json_options(self.game_params["id"],
                                           self.username)

    def _send_client(self, watcher):
        h = hashlib.sha1(utf8(os.path.abspath(self.client_path)))
        if self.crawl_version:
            h.update(utf8(self.crawl_version))
        v = h.hexdigest()
        GameDataHandler.add_version(v,
                                    os.path.join(self.client_path, "static"))

        templ_path = os.path.join(self.client_path, "templates")
        loader = DynamicTemplateLoader.get(templ_path)
        templ = loader.load("game.html")
        game_html = to_unicode(templ.generate(version = v))
        watcher.send_message("game_client", version = v, content = game_html)

    def stop(self):
        if self.process:
            self.process.send_signal(subprocess.signal.SIGHUP)
            t = time.time() + config.get('kill_timeout')
            self.kill_timeout = IOLoop.current().add_timeout(t, self.kill)

    def kill(self):
        if self.process:
            self.logger.info("Killing crawl process after SIGHUP did nothing.")
            self.process.send_signal(subprocess.signal.SIGABRT)
            self.kill_timeout = None

    interesting_info = ("xl", "char", "place", "turn", "dur", "god", "title")

    def set_where_info(self, newwhere):
        # milestone doesn't count as "interesting" but the field is directly
        # handled when sending lobby info by looking at last_milestone
        milestone = bool(newwhere.get("milestone"))
        interesting = (milestone
                or newwhere.get("status") == "chargen"
                or any([self.where.get(key) != newwhere.get(key)
                        for key in CrawlProcessHandlerBase.interesting_info]))

        # ignore milestone sync messages for where purposes
        if newwhere.get("status") != "milestone_only":
            self.where = newwhere
        if milestone:
            self.last_milestone = newwhere
        if interesting:
            update_all_lobbys(self)

    def check_where(self):
        if self.receiving_direct_milestones:
            return
        morgue_path = self.config_path("morgue_path")
        wherefile = os.path.join(morgue_path, self.username + ".where")
        try:
            if os.path.getmtime(wherefile) > self.wheretime:
                self.wheretime = time.time()
                with open(wherefile, "r") as f:
                    wheredata = f.read()

                if wheredata.strip() == "": return

                try:
                    newwhere = parse_where_data(wheredata)
                except Exception:
                    self.logger.warning("Exception while trying to parse where file!",
                                        exc_info=True)
                else:
                    if (newwhere.get("status") == "active" or
                                        newwhere.get("status") == "saved"):
                        self.set_where_info(newwhere)
        except (OSError, IOError):
            pass

    def lobby_entry(self):
        entry = {
            "id": self.id,
            "username": self.username,
            "spectator_count": self.watcher_count(),
            "idle_time": (self.idle_time() if self.is_idle() else 0),
            "game_id": self.game_params["id"],
            }
        for key in CrawlProcessHandlerBase.interesting_info:
            if key in self.where:
                entry[key] = self.where[key]
        if self.last_milestone and self.last_milestone.get("milestone"):
            entry["milestone"] = self.last_milestone.get("milestone")
        return entry

    def human_readable_where(self):
        try:
            return "L{xl} {char}, {place}".format(**self.where)
        except KeyError:
            return ""

    def _base_call(self):
        game = self.game_params


        call  = [game["crawl_binary"]]

        if "pre_options" in game:
            call += game["pre_options"]

        call += ["-name",   self.username,
                 "-rc",     os.path.join(self.config_path("rcfile_path"),
                                         self.username + ".rc"),
                 "-macro",  os.path.join(self.config_path("macro_path"),
                                         self.username + ".macro"),
                 "-morgue", self.config_path("morgue_path")]

        if "options" in game:
            call += game["options"]

        if "dir_path" in game:
            call += ["-dir", self.config_path("dir_path")]

        return call

    def note_activity(self):
        self.last_activity_time = time.time()
        self.check_idle()

    def handle_input(self, msg):
        raise NotImplementedError()
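
# A standalone sketch of the idle-check pattern used above: a Tornado
# PeriodicCallback polls an activity timestamp and reports only on
# idle/active transitions. IdleWatcher and on_change are hypothetical names.
import time
from tornado.ioloop import PeriodicCallback

class IdleWatcher(object):
    def __init__(self, on_change, threshold=30, poll_ms=10000):
        self.on_change = on_change
        self.threshold = threshold
        self.last_activity_time = time.time()
        self._was_idle = False
        self._checker = PeriodicCallback(self._check, poll_ms)
        self._checker.start()

    def note_activity(self):
        self.last_activity_time = time.time()
        self._check()

    def is_idle(self):
        return time.time() - self.last_activity_time > self.threshold

    def _check(self):
        # Only report when the idle state actually flips.
        if self.is_idle() != self._was_idle:
            self._was_idle = self.is_idle()
            self.on_change(self._was_idle)

    def stop(self):
        self._checker.stop()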
Example #46
0
class PollingHandler(BaseSocketHandler):
    """
    This class represents a separate websocket connection.

    Attributes:
        tracker: tornado.ioloop.PeriodicCallback with the get_location method
            as its callback. Starts when the user pushes the "track" button.
            When started, it runs every 5 seconds to find out and update the
            character's location.

        q: tornado.queues.Queue used for running tasks successively.

        updating: A flag indicating whether the router is being updated.
            Required to avoid race conditions.
    """
    def __init__(self, *args, **kwargs):
        super(PollingHandler, self).__init__(*args, **kwargs)
        # Set up a Tornado PeriodicCallback with our self.get_location; we
        # will start it later on the track/untrack commands
        self.tracker = PeriodicCallback(self.get_location, 5000)
        self.q = Queue(maxsize=5)
        self.updating = False

    async def get_location(self):
        """
        The callback for the `self.tracker`.
        
        Makes an API call, updates the router, and sends the updated data to
        the front-end.
        """
        # Call API to find out current character location
        location = await self.character(self.user_id, '/location/', 'GET')

        if location:
            # Set `updating` flag to not accept periodic updates
            # from front-end, to not overwrite new data
            self.updating = True
            user = self.user
            graph_data = await user['router'].update(
                location['solarSystem']['name'])
            if graph_data:
                message = ['update', graph_data]
                logging.warning(graph_data)
                await self.safe_write(message)
            self.updating = False
        else:
            message = ['warning', 'Log into game to track your route']
            await self.safe_write(message)

    async def scheduler(self):
        """
        Scheduler for user tasks.
        
        Waits until there is a new item in the queue, runs the task, then
        marks it as resolved.
        Tornado queues doc: http://www.tornadoweb.org/en/stable/queues.html

        Since we have no guarantee of the order of the incoming messages
        (a new front-end message can arrive before the current one is done),
        we need to ensure that all tasks run successively. Hence the
        asynchronous generator.
        """
        logging.info(f"Scheduler started for {self.request.remote_ip}")

        # Wait on each iteration until there's actually an item available
        async for item in self.q:
            logging.debug(f"Started resolving task for {item}...")
            user = self.user
            try:
                if item == 'recover':
                    # Send saved route
                    await self.safe_write(['recover', user['router'].recovery])

                elif item == 'track':
                    # Start the PeriodicCallback
                    if not self.tracker.is_running():
                        self.tracker.start()

                elif item in ['stop', 'reset']:
                    # Stop the PeriodicCallback
                    if self.tracker.is_running():
                        self.tracker.stop()
                    # Clear all saved data
                    if item == 'reset':
                        await user['router'].reset()

                elif item[0] == 'backup':
                    # Do not overwrite user object while it's updating,
                    # just in case, to avoid race conditions.
                    if not self.updating:
                        await user['router'].backup(item[1])
            finally:
                self.q.task_done()
                logging.debug(f'Task "{item}" done.')

    async def task(self, item):
        """
        Intermediary between `self.on_message` and `self.scheduler`.
        
        Since we cannot do anything asynchronous in the `self.on_message`,
        this method can handle any additional non-blocking stuff if we need it.
        
        :argument item: item to pass to the `self.scheduler`.
        """
        await self.q.put(item)
        #await self.q.join()

    def open(self):
        """
        Triggers on successful websocket connection.
        
        Ensures user is authorized, spawns `self.scheduler` for user
        tasks, adds this websocket object to the connections pool,
        spawns the recovery of the saved route.
        """
        logging.info(f"Connection received from {self.request.remote_ip}")
        if self.user_id:
            self.spawn(self.scheduler)
            self.vagrants.append(self)

            self.spawn(self.task, 'recover')
        else:
            self.close()

    def on_message(self, message):
        """
        Triggers on receiving front-end message.
        
        :argument message: front-end message.
        
        Receives user commands and passes them
        to the `self.scheduler` via `self.task`.
        """
        self.spawn(self.task, json_decode(message))

    def on_close(self):
        """
        Triggers on closed websocket connection.
        
        Removes this websocket object from the connections pool,
        stops `self.tracker` if it is running.
        """
        self.vagrants.remove(self)
        if self.tracker.is_running():
            self.tracker.stop()
        logging.info("Connection closed, " + self.request.remote_ip)
Example #47
0
class PeriodicCallback(param.Parameterized):
    """
    PeriodicCallback encapsulates a periodic callback which will run both
    in tornado-based notebook environments and on the Bokeh server. By
    default the callback will run until the stop method is called,
    but count and timeout values can be set to limit the number of
    executions or the maximum length of time for which the callback
    will run.
    """

    callback = param.Callable(doc="""
       The callback to execute periodically.""")

    count = param.Integer(default=None, doc="""
        Number of times the callback will be executed, by default
        this is unlimited.""")

    period = param.Integer(default=500, doc="""
        Period in milliseconds at which the callback is executed.""")

    timeout = param.Integer(default=None, doc="""
        Timeout in seconds from the start time at which the callback
        expires""")

    def __init__(self, **params):
        super(PeriodicCallback, self).__init__(**params)
        self._counter = 0
        self._start_time = None
        self._timeout = None
        self._cb = None
        self._doc = None

    def start(self):
        if self._cb is not None:
            raise RuntimeError('Periodic callback has already started.')
        self._start_time = time.time()
        if _curdoc().session_context:
            self._doc = _curdoc()
            self._cb = self._doc.add_periodic_callback(self._periodic_callback, self.period)
        else:
            from tornado.ioloop import PeriodicCallback
            self._cb = PeriodicCallback(self._periodic_callback, self.period)
            self._cb.start()

    def _periodic_callback(self):
        self.callback()
        self._counter += 1
        if self._timeout is not None:
            dt = (time.time() - self._start_time)
            if dt > self._timeout:
                self.stop()
        if self._counter == self.count:
            self.stop()

    def stop(self):
        self._counter = 0
        self._timeout = None
        if self._doc:
            self._doc.remove_periodic_callback(self._cb)
        else:
            self._cb.stop()
        self._cb = None
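
# A minimal usage sketch for the parameterized PeriodicCallback above; outside
# a Bokeh session it falls back to tornado's PeriodicCallback, so a running
# IOLoop is assumed for the callback to actually fire.
def tick():
    print("tick")

cb = PeriodicCallback(callback=tick, period=1000, count=5)
cb.start()   # runs tick() every second and stops itself after five calls
# cb.stop()  # or stop it early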
Example #48
0
class BokehTornado(TornadoApplication):
    ''' A Tornado Application used to implement the Bokeh Server.

    Args:
        applications (dict[str,Application] or Application) :
            A map from paths to ``Application`` instances.

            If the value is a single Application, then the following mapping
            is generated:

            .. code-block:: python

                applications = {{ '/' : applications }}

            When a connection comes in to a given path, the associated
            Application is used to generate a new document for the session.

        prefix (str, optional) :
            A URL prefix to use for all Bokeh server paths. (default: None)

        extra_websocket_origins (list[str], optional) :
            A list of hosts that can connect to the websocket.

            This is typically required when embedding a Bokeh server app in an
            external web site using :func:`~bokeh.embed.server_document` or
            similar.

            If None, ``["localhost"]`` will be assumed (default: None)

        extra_patterns (seq[tuple], optional) :
            A list of tuples of (str, http or websocket handler)

            Use this argument to add additional endpoints to custom deployments
            of the Bokeh Server.

            If None, then ``[]`` will be used. (default: None)

        secret_key (str, optional) :
            A secret key for signing session IDs.

            Defaults to the current value of the environment variable
            ``BOKEH_SECRET_KEY``

        sign_sessions (bool, optional) :
            Whether to cryptographically sign session IDs

            Defaults to the current value of the environment variable
            ``BOKEH_SIGN_SESSIONS``. If ``True``, then ``secret_key`` must
            also be provided (either via environment setting or passed as
            a parameter value)

        generate_session_ids (bool, optional) :
            Whether to generate a session ID if one is not provided
            (default: True)

        keep_alive_milliseconds (int, optional) :
            Number of milliseconds between keep-alive pings
            (default: {DEFAULT_KEEP_ALIVE_MS})

            Pings are normally required to keep the websocket open. Set to 0 to
            disable pings.

        check_unused_sessions_milliseconds (int, optional) :
            Number of milliseconds between checking for unused sessions
            (default: {DEFAULT_CHECK_UNUSED_MS})

        unused_session_lifetime_milliseconds (int, optional) :
            Number of milliseconds for unused session lifetime
            (default: {DEFAULT_UNUSED_LIFETIME_MS})

        stats_log_frequency_milliseconds (int, optional) :
            Number of milliseconds between logging stats
            (default: {DEFAULT_STATS_LOG_FREQ_MS})

        mem_log_frequency_milliseconds (int, optional) :
            Number of milliseconds between logging memory information
            (default: {DEFAULT_MEM_LOG_FREQ_MS})

            Enabling this feature requires the optional dependency ``psutil`` to be
            installed.

        use_index (bool, optional) :
            Whether to generate an index of running apps in the ``RootHandler``
            (default: True)

        index (str, optional) :
            Path to a Jinja2 template to serve as the index for "/" if use_index
            is True. If None, the basic built in app index template is used.
            (default: None)

        redirect_root (bool, optional) :
            When there is only a single running application, whether to
            redirect requests to ``"/"`` to that application automatically
            (default: True)

            If there are multiple Bokeh applications configured, this option
            has no effect.

        websocket_max_message_size_bytes (int, optional):
            Set the Tornado ``websocket_max_message_size`` value.
            (default: {DEFAULT_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES})

        auth_provider (AuthProvider, optional):
            An AuthProvider instance

    Any additional keyword arguments are passed to ``tornado.web.Application``.
    '''
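
    # A minimal construction sketch (hypothetical app path and websocket
    # origin; the remaining keyword arguments are left at their defaults):
    #
    #   from bokeh.application import Application
    #   from tornado.ioloop import IOLoop
    #
    #   tornado_app = BokehTornado({"/myapp": Application()},
    #                              extra_websocket_origins=["localhost:5006"])
    #   tornado_app.initialize(IOLoop.current())
    #   tornado_app.start()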

    def __init__(self,
                 applications,
                 prefix=None,
                 extra_websocket_origins=None,
                 extra_patterns=None,
                 secret_key=settings.secret_key_bytes(),
                 sign_sessions=settings.sign_sessions(),
                 generate_session_ids=True,
                 keep_alive_milliseconds=DEFAULT_KEEP_ALIVE_MS,
                 check_unused_sessions_milliseconds=DEFAULT_CHECK_UNUSED_MS,
                 unused_session_lifetime_milliseconds=DEFAULT_UNUSED_LIFETIME_MS,
                 stats_log_frequency_milliseconds=DEFAULT_STATS_LOG_FREQ_MS,
                 mem_log_frequency_milliseconds=DEFAULT_MEM_LOG_FREQ_MS,
                 use_index=True,
                 redirect_root=True,
                 websocket_max_message_size_bytes=DEFAULT_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES,
                 index=None,
                 auth_provider=NullAuth(),
                 xsrf_cookies=False,
                 **kwargs):

        # This will be set when initialize is called
        self._loop = None

        if isinstance(applications, Application):
            applications = { '/' : applications }

        if prefix is None:
            prefix = ""
        prefix = prefix.strip("/")
        if prefix:
            prefix = "/" + prefix

        self._prefix = prefix

        self._index = index

        if keep_alive_milliseconds < 0:
            # 0 means "disable"
            raise ValueError("keep_alive_milliseconds must be >= 0")
        else:
            if keep_alive_milliseconds == 0:
                log.info("Keep-alive ping disabled")
            elif keep_alive_milliseconds != DEFAULT_KEEP_ALIVE_MS:
                log.info("Keep-alive ping configured every %d milliseconds", keep_alive_milliseconds)
        self._keep_alive_milliseconds = keep_alive_milliseconds

        if check_unused_sessions_milliseconds <= 0:
            raise ValueError("check_unused_sessions_milliseconds must be > 0")
        elif check_unused_sessions_milliseconds != DEFAULT_CHECK_UNUSED_MS:
            log.info("Check for unused sessions every %d milliseconds", check_unused_sessions_milliseconds)
        self._check_unused_sessions_milliseconds = check_unused_sessions_milliseconds

        if unused_session_lifetime_milliseconds <= 0:
            raise ValueError("check_unused_sessions_milliseconds must be > 0")
        elif unused_session_lifetime_milliseconds != DEFAULT_UNUSED_LIFETIME_MS:
            log.info("Unused sessions last for %d milliseconds", unused_session_lifetime_milliseconds)
        self._unused_session_lifetime_milliseconds = unused_session_lifetime_milliseconds

        if stats_log_frequency_milliseconds <= 0:
            raise ValueError("stats_log_frequency_milliseconds must be > 0")
        elif stats_log_frequency_milliseconds != DEFAULT_STATS_LOG_FREQ_MS:
            log.info("Log statistics every %d milliseconds", stats_log_frequency_milliseconds)
        self._stats_log_frequency_milliseconds = stats_log_frequency_milliseconds

        if mem_log_frequency_milliseconds < 0:
            # 0 means "disable"
            raise ValueError("mem_log_frequency_milliseconds must be >= 0")
        elif mem_log_frequency_milliseconds > 0:
            if import_optional('psutil') is None:
                log.warning("Memory logging requested, but is disabled. Optional dependency 'psutil' is missing. "
                         "Try 'pip install psutil' or 'conda install psutil'")
                mem_log_frequency_milliseconds = 0
            elif mem_log_frequency_milliseconds != DEFAULT_MEM_LOG_FREQ_MS:
                log.info("Log memory usage every %d milliseconds", mem_log_frequency_milliseconds)
        self._mem_log_frequency_milliseconds = mem_log_frequency_milliseconds

        if websocket_max_message_size_bytes <= 0:
            raise ValueError("websocket_max_message_size_bytes must be positive")
        elif websocket_max_message_size_bytes != DEFAULT_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES:
            log.info("Torndado websocket_max_message_size set to %d bytes (%0.2f MB)",
                     websocket_max_message_size_bytes,
                     websocket_max_message_size_bytes/1024.0**2)

        self.auth_provider = auth_provider

        if self.auth_provider.get_user or self.auth_provider.get_user_async:
            log.info("User authentication hooks provided (no default user)")
        else:
            log.info("User authentication hooks NOT provided (default user enabled)")

        kwargs['xsrf_cookies'] = xsrf_cookies
        if xsrf_cookies:
            log.info("XSRF cookie protection enabled")

        if extra_websocket_origins is None:
            self._websocket_origins = set()
        else:
            self._websocket_origins = set(extra_websocket_origins)
        self._secret_key = secret_key
        self._sign_sessions = sign_sessions
        self._generate_session_ids = generate_session_ids
        log.debug("These host origins can connect to the websocket: %r", list(self._websocket_origins))

        # Wrap applications in ApplicationContext
        self._applications = dict()
        for k,v in applications.items():
            self._applications[k] = ApplicationContext(v, url=k, logout_url=self.auth_provider.logout_url)

        extra_patterns = extra_patterns or []
        extra_patterns.extend(self.auth_provider.endpoints)

        all_patterns = []
        for key, app in applications.items():
            app_patterns = []
            for p in per_app_patterns:
                if key == "/":
                    route = p[0]
                else:
                    route = key + p[0]
                route = self._prefix + route
                app_patterns.append((route, p[1], { "application_context" : self._applications[key] }))

            websocket_path = None
            for r in app_patterns:
                if r[0].endswith("/ws"):
                    websocket_path = r[0]
            if not websocket_path:
                raise RuntimeError("Couldn't find websocket path")
            for r in app_patterns:
                r[2]["bokeh_websocket_path"] = websocket_path

            all_patterns.extend(app_patterns)

            # add a per-app static path if requested by the application
            if app.static_path is not None:
                if key == "/":
                    route = "/static/(.*)"
                else:
                    route = key + "/static/(.*)"
                route = self._prefix + route
                all_patterns.append((route, StaticFileHandler, { "path" : app.static_path }))

        for p in extra_patterns + toplevel_patterns:
            if p[1] == RootHandler:
                if use_index:
                    data = {"applications": self._applications,
                            "prefix": self._prefix,
                            "index": self._index,
                            "use_redirect": redirect_root}
                    prefixed_pat = (self._prefix + p[0],) + p[1:] + (data,)
                    all_patterns.append(prefixed_pat)
            else:
                prefixed_pat = (self._prefix + p[0],) + p[1:]
                all_patterns.append(prefixed_pat)

        log.debug("Patterns are:")
        for line in pformat(all_patterns, width=60).split("\n"):
            log.debug("  " + line)

        super().__init__(all_patterns,
                                           websocket_max_message_size=websocket_max_message_size_bytes,
                                           **kwargs)

    def initialize(self, io_loop):
        ''' Start a Bokeh Server Tornado Application on a given Tornado IOLoop.

        '''
        self._loop = io_loop

        for app_context in self._applications.values():
            app_context._loop = self._loop

        self._clients = set()

        self._stats_job = PeriodicCallback(self._log_stats,
                                           self._stats_log_frequency_milliseconds)

        if self._mem_log_frequency_milliseconds > 0:
            self._mem_job = PeriodicCallback(self._log_mem,
                                             self._mem_log_frequency_milliseconds)
        else:
            self._mem_job = None


        self._cleanup_job = PeriodicCallback(self._cleanup_sessions,
                                             self._check_unused_sessions_milliseconds)

        if self._keep_alive_milliseconds > 0:
            self._ping_job = PeriodicCallback(self._keep_alive, self._keep_alive_milliseconds)
        else:
            self._ping_job = None

    @property
    def app_paths(self):
        ''' A list of all application paths for all Bokeh applications
        configured on this Bokeh server instance.

        '''
        return set(self._applications)

    @property
    def index(self):
        ''' Path to a Jinja2 template to serve as the index "/"

        '''
        return self._index

    @property
    def io_loop(self):
        ''' The Tornado IOLoop that this  Bokeh Server Tornado Application
        is running on.

        '''
        return self._loop

    @property
    def prefix(self):
        ''' A URL prefix for this Bokeh Server Tornado Application to use
        for all paths

        '''
        return self._prefix

    @property
    def websocket_origins(self):
        ''' A list of websocket origins permitted to connect to this server.

        '''
        return self._websocket_origins

    @property
    def secret_key(self):
        ''' A secret key for this Bokeh Server Tornado Application to use when
        signing session IDs, if configured.

        '''
        return self._secret_key

    @property
    def sign_sessions(self):
        ''' Whether this Bokeh Server Tornado Application has been configured
        to cryptographically sign session IDs

        If ``True``, then ``secret_key`` must also have been configured.
        '''
        return self._sign_sessions

    @property
    def generate_session_ids(self):
        ''' Whether this Bokeh Server Tornado Application has been configured
        to automatically generate session IDs.

        '''
        return self._generate_session_ids

    def resources(self, absolute_url=None):
        ''' Provide a :class:`~bokeh.resources.Resources` that specifies where
        Bokeh application sessions should load BokehJS resources from.

        Args:
            absolute_url (bool):
                An absolute URL prefix to use for locating resources. If None,
                relative URLs are used (default: None)

        '''
        mode = settings.resources(default="server")
        if mode == "server":
            root_url = urljoin(absolute_url, self._prefix) if absolute_url else self._prefix
            return Resources(mode="server", root_url=root_url, path_versioner=StaticHandler.append_version)
        return Resources(mode=mode)

    def start(self):
        ''' Start the Bokeh Server application.

        Starting the Bokeh Server Tornado application will run periodic
        callbacks for stats logging, cleanup, pinging, etc. Additionally, any
        startup hooks defined by the configured Bokeh applications will be run.

        '''
        self._stats_job.start()
        if self._mem_job is not None:
            self._mem_job.start()
        self._cleanup_job.start()
        if self._ping_job is not None:
            self._ping_job.start()

        for context in self._applications.values():
            self._loop.spawn_callback(context.run_load_hook)

    def stop(self, wait=True):
        ''' Stop the Bokeh Server application.

        Args:
            wait (bool): whether to wait for orderly cleanup (default: True)

        Returns:
            None

        '''

        # TODO should probably close all connections and shut down all sessions here
        for context in self._applications.values():
            context.run_unload_hook()

        self._stats_job.stop()
        if self._mem_job is not None:
            self._mem_job.stop()
        self._cleanup_job.stop()
        if self._ping_job is not None:
            self._ping_job.stop()

        self._clients.clear()

    def new_connection(self, protocol, socket, application_context, session):
        connection = ServerConnection(protocol, socket, application_context, session)
        self._clients.add(connection)
        return connection

    def client_lost(self, connection):
        self._clients.discard(connection)
        connection.detach_session()

    def get_session(self, app_path, session_id):
        ''' Get an active session by application path and session ID.

        Args:
            app_path (str) :
                The configured application path for the application to return
                a session for.

            session_id (str) :
                The session ID of the session to retrieve.

        Returns:
            ServerSession

        '''
        if app_path not in self._applications:
            raise ValueError("Application %s does not exist on this server" % app_path)
        return self._applications[app_path].get_session(session_id)

    def get_sessions(self, app_path):
        ''' Gets all currently active sessions for an application.

        Args:
            app_path (str) :
                The configured application path for the application to return
                sessions for.

        Returns:
            list[ServerSession]

        '''
        if app_path not in self._applications:
            raise ValueError("Application %s does not exist on this server" % app_path)
        return list(self._applications[app_path].sessions)

    # Periodic Callbacks ------------------------------------------------------

    async def _cleanup_sessions(self):
        log.trace("Running session cleanup job")
        for app in self._applications.values():
            await app._cleanup_sessions(self._unused_session_lifetime_milliseconds)
        return None

    def _log_stats(self):
        log.trace("Running stats log job")

        if log.getEffectiveLevel() > logging.DEBUG:
            # avoid the work below if we aren't going to log anything
            return

        log.debug("[pid %d] %d clients connected", os.getpid(), len(self._clients))
        for app_path, app in self._applications.items():
            sessions = list(app.sessions)
            unused_count = 0
            for s in sessions:
                if s.connection_count == 0:
                    unused_count += 1
            log.debug("[pid %d]   %s has %d sessions with %d unused",
                      os.getpid(), app_path, len(sessions), unused_count)

    def _log_mem(self):
        import psutil

        process = psutil.Process(os.getpid())
        log.info("[pid %d] Memory usage: %0.2f MB (RSS), %0.2f MB (VMS)", os.getpid(), process.memory_info().rss//2**20, process.memory_info().vms//2**20)

        if log.getEffectiveLevel() > logging.DEBUG:
            # avoid the work below if we aren't going to log anything else
            return

        import gc
        from ..document import Document
        from ..model import Model
        from .session import ServerSession
        for name, typ in [('Documents', Document), ('Sessions', ServerSession), ('Models', Model)]:
            objs = [x for x in gc.get_objects() if isinstance(x, typ)]
            log.debug("  uncollected %s: %d", name, len(objs))

    def _keep_alive(self):
        log.trace("Running keep alive job")
        for c in self._clients:
            c.send_ping()
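
# A minimal standalone sketch of the periodic-job lifecycle used above: build
# the PeriodicCallback up front, start it in start(), stop it in stop().  The
# class name and interval below are illustrative, not part of the Bokeh API.
from tornado.ioloop import IOLoop, PeriodicCallback


class _PeriodicJobSketch:
    def __init__(self, frequency_milliseconds=500):
        self._count = 0
        self._job = PeriodicCallback(self._tick, frequency_milliseconds)

    def _tick(self):
        # Runs on the IOLoop every `frequency_milliseconds`.
        self._count += 1
        print("tick", self._count)

    def start(self):
        self._job.start()

    def stop(self):
        self._job.stop()


if __name__ == "__main__":
    sketch = _PeriodicJobSketch()
    sketch.start()
    loop = IOLoop.current()
    loop.call_later(3, loop.stop)  # run for ~3 seconds, then exit
    loop.start()
    sketch.stop()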
Beispiel #49
0
class EventHandler:
    events_enable_interval = 5000  # in milliseconds (PeriodicCallback intervals)
    # Maximum number of finished items to keep track of
    max_finished_history = 1000
    # celery events that represent a task finishing
    finished_events = (
        'task-succeeded',
        'task-failed',
        'task-rejected',
        'task-revoked',
    )

    def __init__(self, capp, io_loop):
        """Monitors events that are received from celery.

        capp - The celery app
        io_loop - The event loop to use for dispatch
        """
        super().__init__()

        self.capp = capp
        self.timer = PeriodicCallback(self.on_enable_events,
                                      self.events_enable_interval)

        self.monitor = EventMonitor(self.capp, io_loop)
        self.listeners = {}
        self.finished_tasks = LRUCache(self.max_finished_history)

    @tornado.gen.coroutine
    def start(self):
        """Start event handler.

        Expects to be run as a coroutine.
        """
        self.timer.start()
        logger.debug('Starting celery monitor thread')
        self.monitor.start()

        while True:
            event = yield self.monitor.events.get()
            try:
                task_id = event['uuid']
            except KeyError:
                continue

            # Record finished tasks in case they are requested
            # too late or are re-requested.
            if event['type'] in self.finished_events:
                self.finished_tasks[task_id] = event

            try:
                callback = self.listeners[task_id]
            except KeyError:
                pass
            else:
                callback(event)

    def stop(self):
        self.timer.stop()
        # FIXME: can not be stopped gracefully
        # self.monitor.stop()

    def on_enable_events(self):
        """Called periodically to enable events for workers
        launched after the monitor.
        """
        try:
            self.capp.control.enable_events()
        except Exception as e:
            logger.debug('Failed to enable events: %s', e)

    def add_listener(self, task_id, callback):
        """Add event listener for a task with ID `task_id`."""
        try:
            event = self.finished_tasks[task_id]
        except KeyError:
            self.listeners[task_id] = callback
        else:
            # Task has already finished
            callback(event)

    def remove_listener(self, task_id):
        """Remove listener for `task_id`."""
        try:
            del self.listeners[task_id]
        except KeyError:  # may have been cached
            pass
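
# A hedged, minimal sketch of the "periodically enable events" pattern used by
# EventHandler above, detached from the EventMonitor/LRUCache machinery.  It
# assumes an already configured Celery app is passed in as `capp`; the helper
# name is illustrative and the 5000 ms default mirrors events_enable_interval.
from tornado.ioloop import PeriodicCallback


def make_enable_events_timer(capp, interval_ms=5000):
    def on_enable_events():
        try:
            # Ask all workers to start emitting task events.
            capp.control.enable_events()
        except Exception as exc:
            print('Failed to enable events:', exc)

    return PeriodicCallback(on_enable_events, interval_ms)

# Usage (assuming `capp` exists and a Tornado IOLoop is running):
#   timer = make_enable_events_timer(capp)
#   timer.start()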
Beispiel #50
0
class BokehTornado(TornadoApplication):
    ''' A Tornado Application used to implement the Bokeh Server.

        The Server class is the main public interface; this class holds
        the Tornado implementation details.

    Args:
        applications (dict of str : bokeh.application.Application) : map from paths to Application instances
            The application is used to create documents for each session.
        extra_patterns (seq[tuple]) : tuples of (str, http or websocket handler)
            Use this argument to add additional endpoints to custom deployments
            of the Bokeh Server.
        prefix (str) : a URL prefix to use for all Bokeh server paths
        hosts (list) : hosts that are valid values for the Host header
        secret_key (str) : secret key for signing session IDs
        sign_sessions (boolean) : whether to sign session IDs
        generate_session_ids (boolean) : whether to generate a session ID when none is provided
        extra_websocket_origins (list) : hosts that can connect to the websocket
            These are in addition to ``hosts``.
        keep_alive_milliseconds (int) : number of milliseconds between keep-alive pings
            Set to 0 to disable pings. Pings keep the websocket open.
        develop (boolean) : True for develop mode

    '''

    def __init__(self, applications, prefix, hosts,
                 extra_websocket_origins,
                 io_loop=None,
                 extra_patterns=None,
                 secret_key=settings.secret_key_bytes(),
                 sign_sessions=settings.sign_sessions(),
                 generate_session_ids=True,
                 # heroku and nginx default to a 60s timeout, so stay well below that
                 keep_alive_milliseconds=37000,
                 # how often to check for unused sessions
                 check_unused_sessions_milliseconds=17000,
                 # how long unused sessions last
                 unused_session_lifetime_milliseconds=60*30*1000,
                 # how often to log stats
                 stats_log_frequency_milliseconds=15000,
                 develop=False):

        self._prefix = prefix

        if io_loop is None:
            io_loop = IOLoop.current()
        self._loop = io_loop

        if keep_alive_milliseconds < 0:
            # 0 means "disable"
            raise ValueError("keep_alive_milliseconds must be >= 0")

        self._hosts = set(hosts)
        self._websocket_origins = self._hosts | set(extra_websocket_origins)
        self._resources = {}
        self._develop = develop
        self._secret_key = secret_key
        self._sign_sessions = sign_sessions
        self._generate_session_ids = generate_session_ids

        log.debug("Allowed Host headers: %r", list(self._hosts))
        log.debug("These host origins can connect to the websocket: %r", list(self._websocket_origins))

        # Wrap applications in ApplicationContext
        self._applications = dict()
        for k,v in applications.items():
            self._applications[k] = ApplicationContext(v, self._develop, self._loop)

        extra_patterns = extra_patterns or []
        all_patterns = []
        for key in applications:
            app_patterns = []
            for p in per_app_patterns:
                if key == "/":
                    route = p[0]
                else:
                    route = key + p[0]
                route = self._prefix + route
                app_patterns.append((route, p[1], { "application_context" : self._applications[key] }))

            websocket_path = None
            for r in app_patterns:
                if r[0].endswith("/ws"):
                    websocket_path = r[0]
            if not websocket_path:
                raise RuntimeError("Couldn't find websocket path")
            for r in app_patterns:
                r[2]["bokeh_websocket_path"] = websocket_path

            all_patterns.extend(app_patterns)

        for p in extra_patterns + toplevel_patterns:
            prefixed_pat = (self._prefix+p[0],) + p[1:]
            all_patterns.append(prefixed_pat)

        for pat in all_patterns:
            _whitelist(pat[1])

        log.debug("Patterns are: %r", all_patterns)

        super(BokehTornado, self).__init__(all_patterns)

        self._clients = set()
        self._executor = ProcessPoolExecutor(max_workers=4)
        self._loop.add_callback(self._start_async)
        self._stats_job = PeriodicCallback(self.log_stats,
                                           stats_log_frequency_milliseconds,
                                           io_loop=self._loop)
        self._unused_session_lifetime_milliseconds = unused_session_lifetime_milliseconds
        self._cleanup_job = PeriodicCallback(self.cleanup_sessions,
                                             check_unused_sessions_milliseconds,
                                             io_loop=self._loop)

        if keep_alive_milliseconds > 0:
            self._ping_job = PeriodicCallback(self.keep_alive, keep_alive_milliseconds, io_loop=self._loop)
        else:
            self._ping_job = None

    @property
    def io_loop(self):
        return self._loop

    @property
    def websocket_origins(self):
        return self._websocket_origins

    @property
    def secret_key(self):
        return self._secret_key

    @property
    def sign_sessions(self):
        return self._sign_sessions

    @property
    def generate_session_ids(self):
        return self._generate_session_ids

    def root_url_for_request(self, request):
        return request.protocol + "://" + request.host + self._prefix + "/"

    def websocket_url_for_request(self, request, websocket_path):
        # websocket_path comes from the handler, and already has any
        # prefix included, no need to add here
        protocol = "ws"
        if request.protocol == "https":
            protocol = "wss"
        return protocol + "://" + request.host + websocket_path

    def resources(self, request):
        root_url = self.root_url_for_request(request)
        if root_url not in self._resources:
            self._resources[root_url] = Resources(mode="server",
                                                  root_url=root_url,
                                                  path_versioner=StaticHandler.append_version)
        return self._resources[root_url]

    def start(self, start_loop=True):
        ''' Start the Bokeh Server application main loop.

        Args:
            start_loop (boolean): False to not actually start event loop, used in tests

        Returns:
            None

        Notes:
            Keyboard interrupts or sigterm will cause the server to shut down.

        '''
        self._stats_job.start()
        self._cleanup_job.start()
        if self._ping_job is not None:
            self._ping_job.start()

        for context in self._applications.values():
            context.run_load_hook()

        if start_loop:
            try:
                self._loop.start()
            except KeyboardInterrupt:
                print("\nInterrupted, shutting down")

    def stop(self):
        ''' Stop the Bokeh Server application.

        Returns:
            None

        '''
        # TODO we should probably close all connections and shut
        # down all sessions either here or in unlisten() ... but
        # it isn't that important since in real life it's rare to
        # do a clean shutdown (vs. a kill-by-signal) anyhow.

        for context in self._applications.values():
            context.run_unload_hook()

        self._stats_job.stop()
        self._cleanup_job.stop()
        if self._ping_job is not None:
            self._ping_job.stop()

        self._loop.stop()

    @property
    def executor(self):
        return self._executor

    def new_connection(self, protocol, socket, application_context, session):
        connection = ServerConnection(protocol, socket, application_context, session)
        self._clients.add(connection)
        return connection

    def client_lost(self, connection):
        self._clients.discard(connection)
        connection.detach_session()

    def get_session(self, app_path, session_id):
        if app_path not in self._applications:
            raise ValueError("Application %s does not exist on this server" % app_path)
        return self._applications[app_path].get_session(session_id)

    def get_sessions(self, app_path):
        if app_path not in self._applications:
            raise ValueError("Application %s does not exist on this server" % app_path)
        return list(self._applications[app_path].sessions)

    @gen.coroutine
    def cleanup_sessions(self):
        for app in self._applications.values():
            yield app.cleanup_sessions(self._unused_session_lifetime_milliseconds)
        raise gen.Return(None)

    def log_stats(self):
        if log.getEffectiveLevel() > logging.DEBUG:
            # avoid the work below if we aren't going to log anything
            return
        log.debug("[pid %d] %d clients connected", os.getpid(), len(self._clients))
        for app_path, app in self._applications.items():
            sessions = list(app.sessions)
            unused_count = 0
            for s in sessions:
                if s.connection_count == 0:
                    unused_count += 1
            log.debug("[pid %d]   %s has %d sessions with %d unused",
                      os.getpid(), app_path, len(sessions), unused_count)

    def keep_alive(self):
        for c in self._clients:
            c.send_ping()

    @gen.coroutine
    def run_in_background(self, _func, *args, **kwargs):
        """
        Run a synchronous function in the background without disrupting
        the main thread. Useful for long-running jobs.
        """
        res = yield self._executor.submit(_func, *args, **kwargs)
        raise gen.Return(res)

    @gen.coroutine
    def _start_async(self):
        try:
            atexit.register(self._atexit)
            signal.signal(signal.SIGTERM, self._sigterm)
        except Exception:
            self.exit(1)

    _atexit_ran = False
    def _atexit(self):
        if self._atexit_ran:
            return
        self._atexit_ran = True

        self._stats_job.stop()
        IOLoop.clear_current()
        loop = IOLoop()
        loop.make_current()
        loop.run_sync(self._cleanup)

    def _sigterm(self, signum, frame):
        print("Received SIGTERM, shutting down")
        self.stop()
        self._atexit()

    @gen.coroutine
    def _cleanup(self):
        log.debug("Shutdown: cleaning up")
        self._executor.shutdown(wait=False)
        self._clients.clear()
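
# A minimal sketch of the run_in_background() pattern above: submit a blocking
# function to an executor and yield the resulting future from a Tornado
# coroutine so the IOLoop stays responsive.  The job function and worker count
# are illustrative.
from concurrent.futures import ProcessPoolExecutor

from tornado import gen
from tornado.ioloop import IOLoop

_executor = ProcessPoolExecutor(max_workers=2)


def _blocking_job(n):
    # Stand-in for a CPU-bound task that would otherwise block the IOLoop.
    return sum(i * i for i in range(n))


@gen.coroutine
def run_in_background(func, *args, **kwargs):
    # Tornado resolves concurrent.futures.Future objects when they are yielded.
    result = yield _executor.submit(func, *args, **kwargs)
    raise gen.Return(result)


if __name__ == "__main__":
    total = IOLoop.current().run_sync(
        lambda: run_in_background(_blocking_job, 10000))
    print(total)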
Beispiel #51
0
class WebSocketProtocol13(WebSocketProtocol):
    """Implementation of the WebSocket protocol from RFC 6455.

    This class supports versions 7 and 8 of the protocol in addition to the
    final version 13.
    """
    # Bit masks for the first byte of a frame.
    FIN = 0x80
    RSV1 = 0x40
    RSV2 = 0x20
    RSV3 = 0x10
    RSV_MASK = RSV1 | RSV2 | RSV3
    OPCODE_MASK = 0x0f

    def __init__(self, handler, mask_outgoing=False,
                 compression_options=None):
        WebSocketProtocol.__init__(self, handler)
        self.mask_outgoing = mask_outgoing
        self._final_frame = False
        self._frame_opcode = None
        self._masked_frame = None
        self._frame_mask = None
        self._frame_length = None
        self._fragmented_message_buffer = None
        self._fragmented_message_opcode = None
        self._waiting = None
        self._compression_options = compression_options
        self._decompressor = None
        self._compressor = None
        self._frame_compressed = None
        # The total uncompressed size of all messages received or sent.
        # Unicode messages are encoded to utf8.
        # Only for testing; subject to change.
        self._message_bytes_in = 0
        self._message_bytes_out = 0
        # The total size of all packets received or sent.  Includes
        # the effect of compression, frame overhead, and control frames.
        self._wire_bytes_in = 0
        self._wire_bytes_out = 0
        self.ping_callback = None
        self.last_ping = 0
        self.last_pong = 0

    def accept_connection(self):
        try:
            self._handle_websocket_headers()
        except ValueError:
            self.handler.set_status(400)
            log_msg = "Missing/Invalid WebSocket headers"
            self.handler.finish(log_msg)
            gen_log.debug(log_msg)
            return

        try:
            self._accept_connection()
        except ValueError:
            gen_log.debug("Malformed WebSocket request received",
                          exc_info=True)
            self._abort()
            return

    def _handle_websocket_headers(self):
        """Verifies all invariant- and required headers

        If a header is missing or have an incorrect value ValueError will be
        raised
        """
        fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version")
        if not all(map(lambda f: self.request.headers.get(f), fields)):
            raise ValueError("Missing/Invalid WebSocket headers")

    @staticmethod
    def compute_accept_value(key):
        """Computes the value for the Sec-WebSocket-Accept header,
        given the value for Sec-WebSocket-Key.
        """
        sha1 = hashlib.sha1()
        sha1.update(utf8(key))
        sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11")  # Magic value
        return native_str(base64.b64encode(sha1.digest()))

    def _challenge_response(self):
        return WebSocketProtocol13.compute_accept_value(
            self.request.headers.get("Sec-Websocket-Key"))

    def _accept_connection(self):
        subprotocols = self.request.headers.get("Sec-WebSocket-Protocol", '')
        subprotocols = [s.strip() for s in subprotocols.split(',')]
        if subprotocols:
            selected = self.handler.select_subprotocol(subprotocols)
            if selected:
                assert selected in subprotocols
                self.handler.set_header("Sec-WebSocket-Protocol", selected)

        extensions = self._parse_extensions_header(self.request.headers)
        for ext in extensions:
            if (ext[0] == 'permessage-deflate' and
                    self._compression_options is not None):
                # TODO: negotiate parameters if compression_options
                # specifies limits.
                self._create_compressors('server', ext[1], self._compression_options)
                if ('client_max_window_bits' in ext[1] and
                        ext[1]['client_max_window_bits'] is None):
                    # Don't echo an offered client_max_window_bits
                    # parameter with no value.
                    del ext[1]['client_max_window_bits']
                self.handler.set_header("Sec-WebSocket-Extensions",
                                        httputil._encode_header(
                                            'permessage-deflate', ext[1]))
                break

        self.handler.clear_header("Content-Type")
        self.handler.set_status(101)
        self.handler.set_header("Upgrade", "websocket")
        self.handler.set_header("Connection", "Upgrade")
        self.handler.set_header("Sec-WebSocket-Accept", self._challenge_response())
        self.handler.finish()

        self.handler._attach_stream()
        self.stream = self.handler.stream

        self.start_pinging()
        self._run_callback(self.handler.open, *self.handler.open_args,
                           **self.handler.open_kwargs)
        self._receive_frame()

    def _parse_extensions_header(self, headers):
        extensions = headers.get("Sec-WebSocket-Extensions", '')
        if extensions:
            return [httputil._parse_header(e.strip())
                    for e in extensions.split(',')]
        return []

    def _process_server_headers(self, key, headers):
        """Process the headers sent by the server to this client connection.

        'key' is the websocket handshake challenge/response key.
        """
        assert headers['Upgrade'].lower() == 'websocket'
        assert headers['Connection'].lower() == 'upgrade'
        accept = self.compute_accept_value(key)
        assert headers['Sec-Websocket-Accept'] == accept

        extensions = self._parse_extensions_header(headers)
        for ext in extensions:
            if (ext[0] == 'permessage-deflate' and
                    self._compression_options is not None):
                self._create_compressors('client', ext[1])
            else:
                raise ValueError("unsupported extension %r", ext)

    def _get_compressor_options(self, side, agreed_parameters, compression_options=None):
        """Converts a websocket agreed_parameters set to keyword arguments
        for our compressor objects.
        """
        options = dict(
            persistent=(side + '_no_context_takeover') not in agreed_parameters)
        wbits_header = agreed_parameters.get(side + '_max_window_bits', None)
        if wbits_header is None:
            options['max_wbits'] = zlib.MAX_WBITS
        else:
            options['max_wbits'] = int(wbits_header)
        options['compression_options'] = compression_options
        return options

    def _create_compressors(self, side, agreed_parameters, compression_options=None):
        # TODO: handle invalid parameters gracefully
        allowed_keys = set(['server_no_context_takeover',
                            'client_no_context_takeover',
                            'server_max_window_bits',
                            'client_max_window_bits'])
        for key in agreed_parameters:
            if key not in allowed_keys:
                raise ValueError("unsupported compression parameter %r" % key)
        other_side = 'client' if (side == 'server') else 'server'
        self._compressor = _PerMessageDeflateCompressor(
            **self._get_compressor_options(side, agreed_parameters, compression_options))
        self._decompressor = _PerMessageDeflateDecompressor(
            **self._get_compressor_options(other_side, agreed_parameters, compression_options))

    def _write_frame(self, fin, opcode, data, flags=0):
        if fin:
            finbit = self.FIN
        else:
            finbit = 0
        frame = struct.pack("B", finbit | opcode | flags)
        data_len = len(data)
        if self.mask_outgoing:
            mask_bit = 0x80
        else:
            mask_bit = 0
        if data_len < 126:
            frame += struct.pack("B", data_len | mask_bit)
        elif data_len <= 0xFFFF:
            frame += struct.pack("!BH", 126 | mask_bit, data_len)
        else:
            frame += struct.pack("!BQ", 127 | mask_bit, data_len)
        if self.mask_outgoing:
            mask = os.urandom(4)
            data = mask + _websocket_mask(mask, data)
        frame += data
        self._wire_bytes_out += len(frame)
        return self.stream.write(frame)

    def write_message(self, message, binary=False):
        """Sends the given message to the client of this Web Socket."""
        if binary:
            opcode = 0x2
        else:
            opcode = 0x1
        message = tornado.escape.utf8(message)
        assert isinstance(message, bytes)
        self._message_bytes_out += len(message)
        flags = 0
        if self._compressor:
            message = self._compressor.compress(message)
            flags |= self.RSV1
        # For historical reasons, write methods in Tornado operate in a semi-synchronous
        # mode in which awaiting the Future they return is optional (But errors can
        # still be raised). This requires us to go through an awkward dance here
        # to transform the errors that may be returned while presenting the same
        # semi-synchronous interface.
        try:
            fut = self._write_frame(True, opcode, message, flags=flags)
        except StreamClosedError:
            raise WebSocketClosedError()

        @gen.coroutine
        def wrapper():
            try:
                yield fut
            except StreamClosedError:
                raise WebSocketClosedError()
        return wrapper()

    def write_ping(self, data):
        """Send ping frame."""
        assert isinstance(data, bytes)
        self._write_frame(True, 0x9, data)

    def _receive_frame(self):
        try:
            self.stream.read_bytes(2, self._on_frame_start)
        except StreamClosedError:
            self._abort()

    def _on_frame_start(self, data):
        self._wire_bytes_in += len(data)
        header, payloadlen = struct.unpack("BB", data)
        self._final_frame = header & self.FIN
        reserved_bits = header & self.RSV_MASK
        self._frame_opcode = header & self.OPCODE_MASK
        self._frame_opcode_is_control = self._frame_opcode & 0x8
        if self._decompressor is not None and self._frame_opcode != 0:
            self._frame_compressed = bool(reserved_bits & self.RSV1)
            reserved_bits &= ~self.RSV1
        if reserved_bits:
            # client is using as-yet-undefined extensions; abort
            self._abort()
            return
        self._masked_frame = bool(payloadlen & 0x80)
        payloadlen = payloadlen & 0x7f
        if self._frame_opcode_is_control and payloadlen >= 126:
            # control frames must have payload < 126
            self._abort()
            return
        try:
            if payloadlen < 126:
                self._frame_length = payloadlen
                if self._masked_frame:
                    self.stream.read_bytes(4, self._on_masking_key)
                else:
                    self._read_frame_data(False)
            elif payloadlen == 126:
                self.stream.read_bytes(2, self._on_frame_length_16)
            elif payloadlen == 127:
                self.stream.read_bytes(8, self._on_frame_length_64)
        except StreamClosedError:
            self._abort()

    def _read_frame_data(self, masked):
        new_len = self._frame_length
        if self._fragmented_message_buffer is not None:
            new_len += len(self._fragmented_message_buffer)
        if new_len > (self.handler.max_message_size or 10 * 1024 * 1024):
            self.close(1009, "message too big")
            return
        self.stream.read_bytes(
            self._frame_length,
            self._on_masked_frame_data if masked else self._on_frame_data)

    def _on_frame_length_16(self, data):
        self._wire_bytes_in += len(data)
        self._frame_length = struct.unpack("!H", data)[0]
        try:
            if self._masked_frame:
                self.stream.read_bytes(4, self._on_masking_key)
            else:
                self._read_frame_data(False)
        except StreamClosedError:
            self._abort()

    def _on_frame_length_64(self, data):
        self._wire_bytes_in += len(data)
        self._frame_length = struct.unpack("!Q", data)[0]
        try:
            if self._masked_frame:
                self.stream.read_bytes(4, self._on_masking_key)
            else:
                self._read_frame_data(False)
        except StreamClosedError:
            self._abort()

    def _on_masking_key(self, data):
        self._wire_bytes_in += len(data)
        self._frame_mask = data
        try:
            self._read_frame_data(True)
        except StreamClosedError:
            self._abort()

    def _on_masked_frame_data(self, data):
        # Don't touch _wire_bytes_in; we'll do it in _on_frame_data.
        self._on_frame_data(_websocket_mask(self._frame_mask, data))

    def _on_frame_data(self, data):
        handled_future = None

        self._wire_bytes_in += len(data)
        if self._frame_opcode_is_control:
            # control frames may be interleaved with a series of fragmented
            # data frames, so control frames must not interact with
            # self._fragmented_*
            if not self._final_frame:
                # control frames must not be fragmented
                self._abort()
                return
            opcode = self._frame_opcode
        elif self._frame_opcode == 0:  # continuation frame
            if self._fragmented_message_buffer is None:
                # nothing to continue
                self._abort()
                return
            self._fragmented_message_buffer += data
            if self._final_frame:
                opcode = self._fragmented_message_opcode
                data = self._fragmented_message_buffer
                self._fragmented_message_buffer = None
        else:  # start of new data message
            if self._fragmented_message_buffer is not None:
                # can't start new message until the old one is finished
                self._abort()
                return
            if self._final_frame:
                opcode = self._frame_opcode
            else:
                self._fragmented_message_opcode = self._frame_opcode
                self._fragmented_message_buffer = data

        if self._final_frame:
            handled_future = self._handle_message(opcode, data)

        if not self.client_terminated:
            if handled_future:
                # on_message is a coroutine, process more frames once it's done.
                handled_future.add_done_callback(
                    lambda future: self._receive_frame())
            else:
                self._receive_frame()

    def _handle_message(self, opcode, data):
        """Execute on_message, returning its Future if it is a coroutine."""
        if self.client_terminated:
            return

        if self._frame_compressed:
            data = self._decompressor.decompress(data)

        if opcode == 0x1:
            # UTF-8 data
            self._message_bytes_in += len(data)
            try:
                decoded = data.decode("utf-8")
            except UnicodeDecodeError:
                self._abort()
                return
            return self._run_callback(self.handler.on_message, decoded)
        elif opcode == 0x2:
            # Binary data
            self._message_bytes_in += len(data)
            return self._run_callback(self.handler.on_message, data)
        elif opcode == 0x8:
            # Close
            self.client_terminated = True
            if len(data) >= 2:
                self.handler.close_code = struct.unpack('>H', data[:2])[0]
            if len(data) > 2:
                self.handler.close_reason = to_unicode(data[2:])
            # Echo the received close code, if any (RFC 6455 section 5.5.1).
            self.close(self.handler.close_code)
        elif opcode == 0x9:
            # Ping
            try:
                self._write_frame(True, 0xA, data)
            except StreamClosedError:
                self._abort()
            self._run_callback(self.handler.on_ping, data)
        elif opcode == 0xA:
            # Pong
            self.last_pong = IOLoop.current().time()
            return self._run_callback(self.handler.on_pong, data)
        else:
            self._abort()

    def close(self, code=None, reason=None):
        """Closes the WebSocket connection."""
        if not self.server_terminated:
            if not self.stream.closed():
                if code is None and reason is not None:
                    code = 1000  # "normal closure" status code
                if code is None:
                    close_data = b''
                else:
                    close_data = struct.pack('>H', code)
                if reason is not None:
                    close_data += utf8(reason)
                try:
                    self._write_frame(True, 0x8, close_data)
                except StreamClosedError:
                    self._abort()
            self.server_terminated = True
        if self.client_terminated:
            if self._waiting is not None:
                self.stream.io_loop.remove_timeout(self._waiting)
                self._waiting = None
            self.stream.close()
        elif self._waiting is None:
            # Give the client a few seconds to complete a clean shutdown,
            # otherwise just close the connection.
            self._waiting = self.stream.io_loop.add_timeout(
                self.stream.io_loop.time() + 5, self._abort)

    @property
    def ping_interval(self):
        interval = self.handler.ping_interval
        if interval is not None:
            return interval
        return 0

    @property
    def ping_timeout(self):
        timeout = self.handler.ping_timeout
        if timeout is not None:
            return timeout
        return max(3 * self.ping_interval, 30)

    def start_pinging(self):
        """Start sending periodic pings to keep the connection alive"""
        if self.ping_interval > 0:
            self.last_ping = self.last_pong = IOLoop.current().time()
            self.ping_callback = PeriodicCallback(
                self.periodic_ping, self.ping_interval * 1000)
            self.ping_callback.start()

    def periodic_ping(self):
        """Send a ping to keep the websocket alive

        Called periodically if the websocket_ping_interval is set and non-zero.
        """
        if self.stream.closed() and self.ping_callback is not None:
            self.ping_callback.stop()
            return

        # Check for timeout on pong. Make sure that we really have
        # sent a recent ping in case the machine with both server and
        # client has been suspended since the last ping.
        now = IOLoop.current().time()
        since_last_pong = now - self.last_pong
        since_last_ping = now - self.last_ping
        if (since_last_ping < 2 * self.ping_interval and
                since_last_pong > self.ping_timeout):
            self.close()
            return

        self.write_ping(b'')
        self.last_ping = now
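
# A small, self-contained sanity check for compute_accept_value() above, using
# the worked handshake example from RFC 6455 section 1.3.  The key/accept pair
# comes from the RFC; the snippet recomputes it with plain hashlib/base64 so it
# runs on its own.
if __name__ == "__main__":
    import base64
    import hashlib

    example_key = "dGhlIHNhbXBsZSBub25jZQ=="
    guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
    accept = base64.b64encode(
        hashlib.sha1((example_key + guid).encode("ascii")).digest()).decode("ascii")
    # Should match both the RFC value and WebSocketProtocol13.compute_accept_value().
    assert accept == "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
    print("Sec-WebSocket-Accept:", accept)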
Beispiel #52
0
class AdaptiveCore:
    """
    The core logic for adaptive deployments, with none of the cluster details

    This class controls our adaptive scaling behavior.  It is intended to be
    used as a super-class or mixin.  It expects the following state and methods:

    **State**

    plan: set
        A set of workers that we think should exist.
        Here and below worker is just a token, often an address or name string

    requested: set
        A set of workers that the cluster class has successfully requested from
        the resource manager.  We expect that resource manager to work to make
        these exist.

    observed: set
        A set of workers that have successfully checked in with the scheduler

    These sets are not necessarily equivalent.  Often plan and requested will
    be very similar (requesting is usually fast) but there may be a large delay
    between requested and observed (often resource managers don't give us what
    we want).

    **Functions**

    target : -> int
        Returns the target number of workers that should exist.
        This is often obtained by querying the scheduler

    workers_to_close : int -> Set[worker]
        Given a target number of workers,
        returns a set of workers that we should close when we're scaling down

    scale_up : int -> None
        Scales the cluster up to a target number of workers, presumably
        changing at least ``plan`` and hopefully eventually also ``requested``

    scale_down : Set[worker] -> None
        Closes the provided set of workers

    Parameters
    ----------
    minimum: int
        The minimum number of allowed workers
    maximum: int | inf
        The maximum number of allowed workers
    wait_count: int
        The number of scale-down requests we should receive before actually
        scaling down
    interval: str
        The amount of time, like ``"1s"``, between checks
    """

    minimum: int
    maximum: int | float
    wait_count: int
    interval: int | float
    periodic_callback: PeriodicCallback | None
    plan: set[WorkerState]
    requested: set[WorkerState]
    observed: set[WorkerState]
    close_counts: defaultdict[WorkerState, int]
    _adapting: bool
    log: deque[tuple[float, dict]]

    def __init__(
        self,
        minimum: int = 0,
        maximum: int | float = math.inf,
        wait_count: int = 3,
        interval: str | int | float | timedelta | None = "1s",
    ):
        if not isinstance(maximum, int) and not math.isinf(maximum):
            raise TypeError(f"maximum must be int or inf; got {maximum}")

        self.minimum = minimum
        self.maximum = maximum
        self.wait_count = wait_count
        self.interval = parse_timedelta(interval, "seconds")
        self.periodic_callback = None

        def f():
            try:
                self.periodic_callback.start()
            except AttributeError:
                pass

        if self.interval:
            import weakref

            self_ref = weakref.ref(self)

            async def _adapt():
                core = self_ref()
                if core:
                    await core.adapt()

            self.periodic_callback = PeriodicCallback(_adapt,
                                                      self.interval * 1000)
            self.loop.add_callback(f)

        try:
            self.plan = set()
            self.requested = set()
            self.observed = set()
        except Exception:
            pass

        # internal state
        self.close_counts = defaultdict(int)
        self._adapting = False
        self.log = deque(maxlen=10000)

    def stop(self) -> None:
        logger.info("Adaptive stop")

        if self.periodic_callback:
            self.periodic_callback.stop()
            self.periodic_callback = None

    async def target(self) -> int:
        """The target number of workers that should exist"""
        raise NotImplementedError()

    async def workers_to_close(self, target: int) -> list:
        """
        Return a list of workers to close that brings us down to the target
        number of workers
        """
        # TODO, improve me with something that thinks about current load
        return list(self.observed)[target:]

    async def safe_target(self) -> int:
        """Used internally, like target, but respects minimum/maximum"""
        n = await self.target()
        if n > self.maximum:
            n = cast(int, self.maximum)

        if n < self.minimum:
            n = self.minimum

        return n

    async def scale_down(self, n: int) -> None:
        raise NotImplementedError()

    async def scale_up(self, workers: Iterable) -> None:
        raise NotImplementedError()

    async def recommendations(self, target: int) -> dict:
        """
        Make scale up/down recommendations based on current state and target
        """
        plan = self.plan
        requested = self.requested
        observed = self.observed

        if target == len(plan):
            self.close_counts.clear()
            return {"status": "same"}

        if target > len(plan):
            self.close_counts.clear()
            return {"status": "up", "n": target}

        # target < len(plan)
        not_yet_arrived = requested - observed
        to_close = set()
        if not_yet_arrived:
            to_close.update(toolz.take(len(plan) - target, not_yet_arrived))

        if target < len(plan) - len(to_close):
            L = await self.workers_to_close(target=target)
            to_close.update(L)

        firmly_close = set()
        for w in to_close:
            self.close_counts[w] += 1
            if self.close_counts[w] >= self.wait_count:
                firmly_close.add(w)

        for k in list(self.close_counts):  # clear out unseen keys
            if k in firmly_close or k not in to_close:
                del self.close_counts[k]

        if firmly_close:
            return {"status": "down", "workers": list(firmly_close)}
        else:
            return {"status": "same"}

    async def adapt(self) -> None:
        """
        Check the current state, make recommendations, call scale

        This is the main event of the system
        """
        if self._adapting:  # Semaphore to avoid overlapping adapt calls
            return
        self._adapting = True
        status = None

        try:

            target = await self.safe_target()
            recommendations = await self.recommendations(target)

            if recommendations["status"] != "same":
                self.log.append((time(), dict(recommendations)))

            status = recommendations.pop("status")
            if status == "same":
                return
            if status == "up":
                await self.scale_up(**recommendations)
            if status == "down":
                await self.scale_down(**recommendations)
        except OSError:
            if status != "down":
                logger.error("Adaptive stopping due to error", exc_info=True)
                self.stop()
            else:
                logger.error("Error during adaptive downscaling. Ignoring.",
                             exc_info=True)
        finally:
            self._adapting = False

    def __del__(self):
        self.stop()

    @property
    def loop(self) -> IOLoop:
        return IOLoop.current()
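
# A hedged sketch of the scale-down hysteresis used by recommendations(): a
# worker is only "firmly" closed after it has been flagged on `wait_count`
# consecutive checks.  Names are illustrative and the logic is reduced to plain
# sets plus a counter.
from collections import defaultdict


def firmly_close(to_close, close_counts, wait_count=3):
    """Return the subset of `to_close` flagged at least `wait_count` times."""
    confirmed = set()
    for worker in to_close:
        close_counts[worker] += 1
        if close_counts[worker] >= wait_count:
            confirmed.add(worker)
    # Forget counts for workers that were confirmed or are no longer flagged.
    for worker in list(close_counts):
        if worker in confirmed or worker not in to_close:
            del close_counts[worker]
    return confirmed


if __name__ == "__main__":
    counts = defaultdict(int)
    for check in range(3):
        # "worker-1" is only returned on the third consecutive check.
        print(check, firmly_close({"worker-1"}, counts))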
Beispiel #53
0
class DataStore:
    def __init__(self, config):
        self.server = config.get_server()
        self.temp_store_size = config.getint('temperature_store_size', 1200)
        self.gcode_store_size = config.getint('gcode_store_size', 1000)

        # Temperature Store Tracking
        self.last_temps = {}
        self.gcode_queue = deque(maxlen=self.gcode_store_size)
        self.temperature_store = {}
        self.temp_update_cb = PeriodicCallback(
            self._update_temperature_store, TEMPERATURE_UPDATE_MS)

        # Register status update event
        self.server.register_event_handler(
            "server:status_update", self._set_current_temps)
        self.server.register_event_handler(
            "server:gcode_response", self._update_gcode_store)
        self.server.register_event_handler(
            "server:klippy_ready", self._init_sensors)

        # Register endpoints
        self.server.register_endpoint(
            "/server/temperature_store", ['GET'],
            self._handle_temp_store_request)
        self.server.register_endpoint(
            "/server/gcode_store", ['GET'],
            self._handle_gcode_store_request)

    async def _init_sensors(self):
        klippy_apis = self.server.lookup_component('klippy_apis')
        # Fetch sensors
        try:
            result = await klippy_apis.query_objects({'heaters': None})
        except self.server.error as e:
            logging.info(f"Error Configuring Sensors: {e}")
            return
        sensors = result.get("heaters", {}).get("available_sensors", [])

        if sensors:
            # Add Subscription
            sub = {s: None for s in sensors}
            try:
                status = await klippy_apis.subscribe_objects(sub)
            except self.server.error as e:
                logging.info(f"Error subscribing to sensors: {e}")
                return
            logging.info(f"Configuring available sensors: {sensors}")
            new_store = {}
            for sensor in sensors:
                fields = list(status.get(sensor, {}).keys())
                if sensor in self.temperature_store:
                    new_store[sensor] = self.temperature_store[sensor]
                else:
                    new_store[sensor] = {
                        'temperatures': deque(maxlen=self.temp_store_size)}
                    for item in ["target", "power", "speed"]:
                        if item in fields:
                            new_store[sensor][f"{item}s"] = deque(
                                maxlen=self.temp_store_size)
                if sensor not in self.last_temps:
                    self.last_temps[sensor] = (0., 0., 0., 0.)
            self.temperature_store = new_store
            # Prune unconfigured sensors in self.last_temps
            for sensor in list(self.last_temps.keys()):
                if sensor not in self.temperature_store:
                    del self.last_temps[sensor]
            # Update initial temperatures
            self._set_current_temps(status)
            self.temp_update_cb.start()
        else:
            logging.info("No sensors found")
            self.last_temps = {}
            self.temperature_store = {}
            self.temp_update_cb.stop()

    def _set_current_temps(self, data):
        for sensor in self.temperature_store:
            if sensor in data:
                last_val = self.last_temps[sensor]
                self.last_temps[sensor] = (
                    round(data[sensor].get('temperature', last_val[0]), 2),
                    data[sensor].get('target', last_val[1]),
                    data[sensor].get('power', last_val[2]),
                    data[sensor].get('speed', last_val[3]))

    def _update_temperature_store(self):
        # XXX - If klippy is not connected, set values to zero
        # as they are unknown?
        for sensor, vals in self.last_temps.items():
            self.temperature_store[sensor]['temperatures'].append(vals[0])
            for val, item in zip(vals[1:], ["targets", "powers", "speeds"]):
                if item in self.temperature_store[sensor]:
                    self.temperature_store[sensor][item].append(val)

    async def _handle_temp_store_request(self, web_request):
        store = {}
        for name, sensor in self.temperature_store.items():
            store[name] = {k: list(v) for k, v in sensor.items()}
        return store

    async def close(self):
        self.temp_update_cb.stop()

    def _update_gcode_store(self, response):
        curtime = time.time()
        self.gcode_queue.append(
            {'message': response, 'time': curtime, 'type': "response"})

    def store_gcode_command(self, script):
        curtime = time.time()
        for cmd in script.split('\n'):
            cmd = cmd.strip()
            if not cmd:
                continue
            self.gcode_queue.append(
                {'message': cmd, 'time': curtime, 'type': "command"})

    async def _handle_gcode_store_request(self, web_request):
        count = web_request.get_int("count", None)
        if count is not None:
            gc_responses = list(self.gcode_queue)[-count:]
        else:
            gc_responses = list(self.gcode_queue)
        return {'gcode_store': gc_responses}
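
# A minimal sketch of the temperature-store pattern above: one fixed-size deque
# per sensor, appended to by a PeriodicCallback at a regular interval.  The
# sensor name, interval, and fake reading below are illustrative.
import random
from collections import deque

from tornado.ioloop import IOLoop, PeriodicCallback

TEMPERATURE_UPDATE_MS = 1000
TEMP_STORE_SIZE = 1200

last_temps = {"extruder": 0.}
temperature_store = {"extruder": deque(maxlen=TEMP_STORE_SIZE)}


def _update_temperature_store():
    # Stand-in for the status updates that normally refresh last_temps.
    last_temps["extruder"] = round(200. + random.uniform(-0.5, 0.5), 2)
    for sensor, value in last_temps.items():
        temperature_store[sensor].append(value)


if __name__ == "__main__":
    cb = PeriodicCallback(_update_temperature_store, TEMPERATURE_UPDATE_MS)
    cb.start()
    loop = IOLoop.current()
    loop.call_later(5, loop.stop)  # sample for ~5 seconds
    loop.start()
    cb.stop()
    print(list(temperature_store["extruder"]))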
Beispiel #54
0
class Engine(BaseEngine):

    NAME = 'Redis'

    OK_RESPONSE = b'OK'

    def __init__(self, *args, **kwargs):
        super(Engine, self).__init__(*args, **kwargs)

        if not self.options.redis_url:
            self.host = self.options.redis_host
            self.port = self.options.redis_port
            self.password = self.options.redis_password
            self.db = self.options.redis_db
        else:
            # according to https://devcenter.heroku.com/articles/redistogo
            parsed_url = urlparse.urlparse(self.options.redis_url)
            self.host = parsed_url.hostname
            self.port = int(parsed_url.port)
            self.db = 0
            self.password = parsed_url.password

        self.connection_check = PeriodicCallback(self.check_connection, 1000)
        self._need_reconnect = False

        self.subscriber = toredis.Client(io_loop=self.io_loop)
        self.publisher = toredis.Client(io_loop=self.io_loop)
        self.worker = toredis.Client(io_loop=self.io_loop)

        self.subscriptions = {}

    def initialize(self):
        self.connect()
        logger.info("Redis engine at {0}:{1} (db {2})".format(
            self.host, self.port, self.db))

    def on_auth(self, res):
        if res != self.OK_RESPONSE:
            logger.error("auth failed: {0}".format(res))

    def on_subscriber_select(self, res):
        """
        After selecting the subscriber database, subscribe to the channels
        """
        if res != self.OK_RESPONSE:
            # error returned
            logger.error("select database failed: {0}".format(res))
            self._need_reconnect = True
            return

        self.subscriber.subscribe(self.admin_channel_name,
                                  callback=self.on_redis_message)
        self.subscriber.subscribe(self.control_channel_name,
                                  callback=self.on_redis_message)

        for subscription in self.subscriptions.copy():
            if subscription not in self.subscriptions:
                continue
            self.subscriber.subscribe(subscription,
                                      callback=self.on_redis_message)

    def on_select(self, res):
        if res != self.OK_RESPONSE:
            logger.error("select database failed: {0}".format(res))
            self._need_reconnect = True

    def connect(self):
        """
        Connect from scratch
        """
        try:
            self.subscriber.connect(host=self.host, port=self.port)
            self.publisher.connect(host=self.host, port=self.port)
            self.worker.connect(host=self.host, port=self.port)
        except Exception as e:
            logger.error("error connecting to Redis server: %s" % (str(e)))
        else:
            if self.password:
                self.subscriber.auth(self.password, callback=self.on_auth)
                self.publisher.auth(self.password, callback=self.on_auth)
                self.worker.auth(self.password, callback=self.on_auth)

            self.subscriber.select(self.db, callback=self.on_subscriber_select)
            self.publisher.select(self.db, callback=self.on_select)
            self.worker.select(self.db, callback=self.on_select)

        self.connection_check.stop()
        self.connection_check.start()

    def check_connection(self):
        conn_statuses = [
            self.subscriber.is_connected(),
            self.publisher.is_connected(),
            self.worker.is_connected()
        ]
        connection_dropped = not all(conn_statuses)
        if connection_dropped or self._need_reconnect:
            logger.info('reconnecting to Redis')
            self._need_reconnect = False
            self.connect()

    def _publish(self, channel, message):
        try:
            self.publisher.publish(channel, message)
        except StreamClosedError as e:
            self._need_reconnect = True
            logger.error(e)
            return False
        else:
            return True

    @coroutine
    def publish_message(self,
                        channel,
                        body,
                        method=BaseEngine.DEFAULT_PUBLISH_METHOD):
        """
        Publish a message into a stream channel.
        """
        response = Response()
        method = method or self.DEFAULT_PUBLISH_METHOD
        response.method = method
        response.body = body
        to_publish = response.as_message()
        result = self._publish(channel, to_publish)
        raise Return((result, None))

    @coroutine
    def publish_control_message(self, message):
        result = self._publish(self.control_channel_name, json_encode(message))
        raise Return((result, None))

    @coroutine
    def publish_admin_message(self, message):
        result = self._publish(self.admin_channel_name, json_encode(message))
        raise Return((result, None))

    @coroutine
    def on_redis_message(self, redis_message):
        """
        Got a message from Redis; dispatch it to the right message handler.
        """
        msg_type = redis_message[0]
        if six.PY3:
            msg_type = msg_type.decode()

        if msg_type != 'message':
            return

        channel = redis_message[1]
        if six.PY3:
            channel = channel.decode()

        if channel == self.control_channel_name:
            yield self.handle_control_message(json_decode(redis_message[2]))
        elif channel == self.admin_channel_name:
            yield self.handle_admin_message(json_decode(redis_message[2]))
        else:
            yield self.handle_message(channel, redis_message[2])

    @coroutine
    def handle_admin_message(self, message):
        message = json_encode(message)
        for uid, connection in six.iteritems(
                self.application.admin_connections):
            if uid not in self.application.admin_connections:
                continue
            connection.send(message)

        raise Return((True, None))

    @coroutine
    def handle_control_message(self, message):
        """
        Handle control message.
        """
        app_id = message.get("app_id")
        method = message.get("method")
        params = message.get("params")

        if app_id and app_id == self.application.uid:
            # app_id must be set when we don't want to do things twice
            # for the same application. Setting app_id means that a
            # control message is skipped by the application instance
            # whose uid matches app_id.
            raise Return((True, None))

        func = getattr(self.application, 'handle_%s' % method, None)
        if not func:
            raise Return((None, self.application.METHOD_NOT_FOUND))

        result, error = yield func(params)
        raise Return((result, error))
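
    # Illustration (the method name below is hypothetical; the concrete
    # handlers live on the application object): a control message such as
    #   {"app_id": "<sender uid>", "method": "unsubscribe", "params": {...}}
    # is dispatched to self.application.handle_unsubscribe(params), while a
    # message whose app_id matches this application's own uid is skipped.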

    @coroutine
    def handle_message(self, channel, message_data):
        if channel not in self.subscriptions:
            raise Return((True, None))
        for uid, client in six.iteritems(self.subscriptions[channel]):
            if channel in self.subscriptions and uid in self.subscriptions[
                    channel]:
                yield client.send(message_data)

    def subscribe_key(self, subscription_key):
        self.subscriber.subscribe(subscription_key,
                                  callback=self.on_redis_message)

    def unsubscribe_key(self, subscription_key):
        self.subscriber.unsubscribe(subscription_key)

    @coroutine
    def add_subscription(self, project_id, channel, client):
        """
        Subscribe application on channel if necessary and register client
        to receive messages from that channel.
        """
        subscription_key = self.get_subscription_key(project_id, channel)

        self.subscribe_key(subscription_key)

        if subscription_key not in self.subscriptions:
            self.subscriptions[subscription_key] = {}

        self.subscriptions[subscription_key][client.uid] = client

        raise Return((True, None))

    @coroutine
    def remove_subscription(self, project_id, channel, client):
        """
        Unsubscribe application from channel if necessary and prevent client
        from receiving messages from that channel.
        """
        subscription_key = self.get_subscription_key(project_id, channel)

        try:
            del self.subscriptions[subscription_key][client.uid]
        except KeyError:
            pass

        try:
            if not self.subscriptions[subscription_key]:
                self.unsubscribe_key(subscription_key)
                del self.subscriptions[subscription_key]
        except KeyError:
            pass

        raise Return((True, None))

    def get_presence_hash_key(self, project_id, channel):
        return "%s:presence:hash:%s:%s" % (self.prefix, project_id, channel)

    def get_presence_set_key(self, project_id, channel):
        return "%s:presence:set:%s:%s" % (self.prefix, project_id, channel)

    def get_history_list_key(self, project_id, channel):
        return "%s:history:list:%s:%s" % (self.prefix, project_id, channel)

    @coroutine
    def add_presence(self,
                     project_id,
                     channel,
                     uid,
                     user_info,
                     presence_timeout=None):
        now = int(time.time())
        expire_at = now + (presence_timeout or self.presence_timeout)
        hash_key = self.get_presence_hash_key(project_id, channel)
        set_key = self.get_presence_set_key(project_id, channel)
        try:
            pipeline = self.worker.pipeline()
            pipeline.multi()
            pipeline.zadd(set_key, {uid: expire_at})
            pipeline.hset(hash_key, uid, json_encode(user_info))
            pipeline.execute()
            yield Task(pipeline.send)
        except StreamClosedError as e:
            raise Return((None, e))
        else:
            raise Return((True, None))

    @coroutine
    def remove_presence(self, project_id, channel, uid):
        hash_key = self.get_presence_hash_key(project_id, channel)
        set_key = self.get_presence_set_key(project_id, channel)
        try:
            pipeline = self.worker.pipeline()
            pipeline.hdel(hash_key, uid)
            pipeline.zrem(set_key, uid)
            yield Task(pipeline.send)
        except StreamClosedError as e:
            raise Return((None, e))
        else:
            raise Return((True, None))

    @coroutine
    def get_presence(self, project_id, channel):
        now = int(time.time())
        hash_key = self.get_presence_hash_key(project_id, channel)
        set_key = self.get_presence_set_key(project_id, channel)
        try:
            expired_keys = yield Task(self.worker.zrangebyscore, set_key, 0,
                                      now)
            if expired_keys:
                pipeline = self.worker.pipeline()
                pipeline.zremrangebyscore(set_key, 0, now)
                pipeline.hdel(hash_key, [x.decode() for x in expired_keys])
                yield Task(pipeline.send)
            data = yield Task(self.worker.hgetall, hash_key)
        except StreamClosedError as e:
            raise Return((None, e))
        else:
            raise Return((dict_from_list(data), None))
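
    # Note: presence entries expire lazily. Each get_presence() call removes
    # ZSET members whose expire-at score has already passed, deletes the
    # matching HASH fields, and only then returns the remaining entries.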

    @coroutine
    def add_history_message(self,
                            project_id,
                            channel,
                            message,
                            history_size=None,
                            history_expire=0):
        history_size = history_size or self.history_size
        history_list_key = self.get_history_list_key(project_id, channel)
        try:
            pipeline = self.worker.pipeline()
            pipeline.lpush(history_list_key, json_encode(message))
            pipeline.ltrim(history_list_key, 0, history_size - 1)
            if history_expire:
                pipeline.expire(history_list_key, history_expire)
            else:
                pipeline.persist(history_list_key)
            yield Task(pipeline.send)
        except StreamClosedError as e:
            raise Return((None, e))
        else:
            raise Return((True, None))

    @coroutine
    def get_history(self, project_id, channel):
        history_list_key = self.get_history_list_key(project_id, channel)
        try:
            data = yield Task(self.worker.lrange, history_list_key, 0, -1)
        except StreamClosedError as e:
            raise Return((None, e))
        else:
            raise Return(([json_decode(x.decode()) for x in data], None))
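
# A minimal standalone sketch (not part of the engine above) of the same
# capped-history pattern used by add_history_message()/get_history(): LPUSH
# the newest entry, LTRIM to the configured size, LRANGE to read back. The
# key name, history size and the synchronous redis-py client are assumptions
# made purely for illustration.
import json

import redis

r = redis.Redis()
history_key = "centrifuge:history:list:p1:news"   # assumed key layout
history_size = 10

def add_history(message):
    pipe = r.pipeline()
    pipe.lpush(history_key, json.dumps(message))   # newest entry first
    pipe.ltrim(history_key, 0, history_size - 1)   # cap the list length
    pipe.execute()

def get_history():
    return [json.loads(x) for x in r.lrange(history_key, 0, -1)]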
Beispiel #55
0
class Reader(Client):
    r"""
    Reader provides high-level functionality for building robust NSQ consumers in Python
    on top of the async module.

    Reader receives messages over the specified ``topic/channel`` and calls ``message_handler``
    for each message (up to ``max_tries``).

    Multiple readers can be instantiated in a single process (to consume from multiple
    topics/channels at once).

    Supports various hooks to modify behavior when heartbeats are received, to temporarily
    disable the reader, and pre-process/validate messages.

    When supplied a list of ``nsqlookupd`` addresses, it will periodically poll those
    addresses to discover new producers of the specified ``topic``.

    It maintains a sufficient RDY count based on the # of producers and your configured
    ``max_in_flight``.

    Handlers should be defined as shown in the examples below. The ``message_handler``
    callback function receives a :class:`nsq.Message` object that has instance methods
    :meth:`nsq.Message.finish`, :meth:`nsq.Message.requeue`, and :meth:`nsq.Message.touch`
    which can be used to respond to ``nsqd``. As an alternative to explicitly calling these
    response methods, the handler function can simply return ``True`` to finish the message,
    or ``False`` to requeue it. If the handler function calls :meth:`nsq.Message.enable_async`,
    then automatic finish/requeue is disabled, allowing the :class:`nsq.Message` to finish or
    requeue in a later async callback or context. The handler function may also be a coroutine,
    in which case Message async handling is enabled automatically, but the coroutine
    can still return a final value of True/False to automatically finish/requeue the message.

    After re-queueing a message, the handler will back off from processing additional messages
    for an increasing delay (calculated exponentially based on consecutive failures up to
    ``max_backoff_duration``).

    Synchronous example::

        import nsq

        def handler(message):
            print(message)
            return True

        r = nsq.Reader(message_handler=handler,
                lookupd_http_addresses=['http://127.0.0.1:4161'],
                topic='nsq_reader', channel='asdf', lookupd_poll_interval=15)
        nsq.run()

    Asynchronous example::

        import nsq

        buf = []

        def process_message(message):
            global buf
            message.enable_async()
            # cache the message for later processing
            buf.append(message)
            if len(buf) >= 3:
                for msg in buf:
                    print(msg)
                    msg.finish()
                buf = []
            else:
                print('deferring processing')

        r = nsq.Reader(message_handler=process_message,
                lookupd_http_addresses=['http://127.0.0.1:4161'],
                topic='nsq_reader', channel='async', max_in_flight=9)
        nsq.run()

    :param message_handler: the callable that will be executed for each message received

    :param topic: specifies the desired NSQ topic

    :param channel: specifies the desired NSQ channel

    :param name: a string that is used for logging messages (defaults to 'topic:channel')

    :param nsqd_tcp_addresses: a sequence of string addresses of the nsqd instances this reader
        should connect to

    :param lookupd_http_addresses: a sequence of string addresses of the nsqlookupd instances this
        reader should query for producers of the specified topic

    :param max_tries: the maximum number of attempts the reader will make to process a message after
        which messages will be automatically discarded

    :param max_in_flight: the maximum number of messages this reader will pipeline for processing.
        this value will be divided evenly amongst the configured/discovered nsqd producers

    :param lookupd_poll_interval: the amount of time in seconds between querying all of the supplied
        nsqlookupd instances.  a random amount of time based on this value will be initially
        introduced in order to add jitter when multiple readers are running

    :param lookupd_poll_jitter: The maximum fractional amount of jitter to add to the
        lookupd poll loop. This helps evenly distribute requests even if multiple consumers
        restart at the same time.

    :param lookupd_connect_timeout: the amount of time in seconds to wait for
        a connection to ``nsqlookupd`` to be established

    :param lookupd_request_timeout: the amount of time in seconds to wait for
        a request to ``nsqlookupd`` to complete.

    :param low_rdy_idle_timeout: the amount of time in seconds to wait for a message from a producer
        when in a state where RDY counts are re-distributed (ie. max_in_flight < num_producers)

    :param max_backoff_duration: the maximum time we will allow a backoff state to last in seconds

    :param \*\*kwargs: passed to :class:`nsq.AsyncConn` initialization
    """
    def __init__(
            self,
            topic,
            channel,
            message_handler=None,
            name=None,
            nsqd_tcp_addresses=None,
            lookupd_http_addresses=None,
            max_tries=5,
            max_in_flight=1,
            lookupd_poll_interval=60,
            low_rdy_idle_timeout=10,
            max_backoff_duration=128,
            lookupd_poll_jitter=0.3,
            lookupd_connect_timeout=1,
            lookupd_request_timeout=2,
            **kwargs):
        super(Reader, self).__init__(**kwargs)

        assert isinstance(topic, string_types) and len(topic) > 0
        assert isinstance(channel, string_types) and len(channel) > 0
        assert isinstance(max_in_flight, int) and max_in_flight > 0
        assert isinstance(max_backoff_duration, (int, float)) and max_backoff_duration > 0
        assert isinstance(name, string_types + (None.__class__,))
        assert isinstance(lookupd_poll_interval, int)
        assert isinstance(lookupd_poll_jitter, float)
        assert isinstance(lookupd_connect_timeout, int)
        assert isinstance(lookupd_request_timeout, int)

        assert lookupd_poll_jitter >= 0 and lookupd_poll_jitter <= 1

        if nsqd_tcp_addresses:
            if not isinstance(nsqd_tcp_addresses, (list, set, tuple)):
                assert isinstance(nsqd_tcp_addresses, string_types)
                nsqd_tcp_addresses = [nsqd_tcp_addresses]
        else:
            nsqd_tcp_addresses = []

        if lookupd_http_addresses:
            if not isinstance(lookupd_http_addresses, (list, set, tuple)):
                assert isinstance(lookupd_http_addresses, string_types)
                lookupd_http_addresses = [lookupd_http_addresses]
            random.shuffle(lookupd_http_addresses)
        else:
            lookupd_http_addresses = []

        assert nsqd_tcp_addresses or lookupd_http_addresses

        self.name = name or (topic + ':' + channel)
        self.message_handler = None
        if message_handler:
            self.set_message_handler(message_handler)
        self.topic = topic
        self.channel = channel
        self.nsqd_tcp_addresses = nsqd_tcp_addresses
        self.lookupd_http_addresses = lookupd_http_addresses
        self.lookupd_query_index = 0
        self.max_tries = max_tries
        self.max_in_flight = max_in_flight
        self.low_rdy_idle_timeout = low_rdy_idle_timeout
        self.total_rdy = 0
        self.need_rdy_redistributed = False
        self.lookupd_poll_interval = lookupd_poll_interval
        self.lookupd_poll_jitter = lookupd_poll_jitter
        self.lookupd_connect_timeout = lookupd_connect_timeout
        self.lookupd_request_timeout = lookupd_request_timeout
        self.random_rdy_ts = time.time()

        # Verify keyword arguments
        valid_args = func_args(AsyncConn.__init__)
        diff = set(kwargs) - set(valid_args)
        assert len(diff) == 0, 'Invalid keyword argument(s): %s' % list(diff)

        self.conn_kwargs = kwargs

        self.backoff_timer = BackoffTimer(0, max_backoff_duration)
        self.backoff_block = False
        self.backoff_block_completed = True

        self.conns = {}
        self.connection_attempts = {}
        self.http_client = tornado.httpclient.AsyncHTTPClient()

        # will execute when run() is called (for all Reader instances)
        self.io_loop.add_callback(self._run)

        self.redist_periodic = None
        self.query_periodic = None

    def _run(self):
        assert self.message_handler, "you must specify the Reader's message_handler"

        logger.info('[%s] starting reader for %s/%s...', self.name, self.topic, self.channel)

        for addr in self.nsqd_tcp_addresses:
            address, port = addr.split(':')
            self.connect_to_nsqd(address, int(port))

        self.redist_periodic = PeriodicCallback(
            self._redistribute_rdy_state,
            5 * 1000,
        )
        self.redist_periodic.start()

        if not self.lookupd_http_addresses:
            return

        # trigger the first lookup query manually
        self.io_loop.spawn_callback(self.query_lookupd)

        self.query_periodic = PeriodicCallback(
            self.query_lookupd,
            self.lookupd_poll_interval * 1000,
        )

        # randomize the time we start this poll loop so that all
        # consumers don't query at exactly the same time
        delay = random.random() * self.lookupd_poll_interval * self.lookupd_poll_jitter
        self.io_loop.call_later(delay, self.query_periodic.start)

    def close(self):
        """
        Closes all connections and stops all periodic callbacks.
        """
        for conn in self.conns.values():
            conn.close()

        self.redist_periodic.stop()
        if self.query_periodic is not None:
            self.query_periodic.stop()

    def set_message_handler(self, message_handler):
        """
        Assigns the callback method to be executed for each message received

        :param message_handler: a callable that takes a single argument
        """
        assert callable(message_handler), 'message_handler must be callable'
        self.message_handler = message_handler

    def _connection_max_in_flight(self):
        return max(1, self.max_in_flight // max(1, len(self.conns)))
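
    # Example: with max_in_flight=9 and 4 connections each connection is
    # allowed max(1, 9 // 4) == 2 in flight; with more connections than
    # max_in_flight the per-connection value bottoms out at 1 and RDY
    # redistribution takes over.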

    def is_starved(self):
        """
        Used to identify when buffered messages should be processed and responded to.

        When max_in_flight > 1 and you're batching messages together to perform work
        it isn't possible to simply compare the length of your list of buffered messages against
        your configured max_in_flight (because max_in_flight may not be evenly divisible
        by the number of producers you're connected to, ie. you might never get that many
        messages... it's a *max*).

        Example::

            def message_handler(self, nsq_msg, reader):
                # buffer messages
                if reader.is_starved():
                    # perform work

            reader = nsq.Reader(...)
            reader.set_message_handler(functools.partial(message_handler, reader=reader))
            nsq.run()
        """
        for conn in itervalues(self.conns):
            if conn.in_flight > 0 and conn.in_flight >= (conn.last_rdy * 0.85):
                return True
        return False

    def _on_message(self, conn, message, **kwargs):
        try:
            self._handle_message(conn, message)
        except Exception:
            logger.exception('[%s:%s] failed to handle_message() %r', conn.id, self.name, message)

    def _handle_message(self, conn, message):
        self._maybe_update_rdy(conn)

        result = False
        try:
            if 0 < self.max_tries < message.attempts:
                self.giving_up(message)
                return message.finish()
            pre_processed_message = self.preprocess_message(message)
            if not self.validate_message(pre_processed_message):
                return message.finish()
            result = self.process_message(message)
        except Exception:
            logger.exception('[%s:%s] uncaught exception while handling message %s body:%r',
                             conn.id, self.name, message.id, message.body)
            if not message.has_responded():
                return message.requeue()

        if result not in (True, False, None):
            # assume handler returned a Future or Coroutine
            message.enable_async()
            fut = tornado.gen.convert_yielded(result)
            fut.add_done_callback(functools.partial(self._maybe_finish, message))

        elif not message.is_async() and not message.has_responded():
            assert result is not None, 'ambiguous return value for synchronous mode'
            if result:
                return message.finish()
            return message.requeue()

    def _maybe_finish(self, message, fut):
        if not message.has_responded():
            try:
                if fut.result():
                    message.finish()
                    return
            except Exception:
                pass
            message.requeue()

    def _maybe_update_rdy(self, conn):
        if self.backoff_timer.get_interval() or self.max_in_flight == 0:
            return

        # Update RDY in 2 cases:
        #     1. On a new connection or in backoff we start with a tentative RDY
        #        count of 1.  After successfully receiving a first message we go to
        #        full throttle.
        #     2. After a change in connection count or max_in_flight we adjust to the new
        #        connection_max_in_flight.
        conn_max_in_flight = self._connection_max_in_flight()
        if conn.rdy == 1 or conn.rdy != conn_max_in_flight:
            self._send_rdy(conn, conn_max_in_flight)

    def _finish_backoff_block(self):
        self.backoff_block = False

        # we must have raced and received a message out of order that resumed
        # so just complete the backoff block
        if not self.backoff_timer.get_interval():
            self._complete_backoff_block()
            return

        # test the waters after finishing a backoff round
        # if we have no connections, this will happen when a new connection gets RDY 1
        if not self.conns or self.max_in_flight == 0:
            return

        conn = random.choice(list(self.conns.values()))
        logger.info('[%s:%s] testing backoff state with RDY 1', conn.id, self.name)
        self._send_rdy(conn, 1)

        # for tests
        return conn

    def _on_backoff_resume(self, success, **kwargs):
        if success:
            self.backoff_timer.success()
        elif success is False and not self.backoff_block:
            self.backoff_timer.failure()

        self._enter_continue_or_exit_backoff()

    def _complete_backoff_block(self):
        self.backoff_block_completed = True
        rdy = self._connection_max_in_flight()
        logger.info('[%s] backoff complete, resuming normal operation (%d connections)',
                    self.name, len(self.conns))
        for c in self.conns.values():
            self._send_rdy(c, rdy)

    def _enter_continue_or_exit_backoff(self):
        # Take care of backoff in the appropriate cases.  When this
        # happens, we set a failure on the backoff timer and set the RDY count to zero.
        # Once the backoff time has expired, we allow *one* of the connections let
        # a single message through to test the water.  This will continue until we
        # reach no backoff in which case we go back to the normal RDY count.

        current_backoff_interval = self.backoff_timer.get_interval()

        # do nothing
        if self.backoff_block:
            return

        # we're out of backoff completely, return to full blast for all conns
        if not self.backoff_block_completed and not current_backoff_interval:
            self._complete_backoff_block()
            return

        # enter or continue a backoff iteration
        if current_backoff_interval:
            self._start_backoff_block()

    def _start_backoff_block(self):
        self.backoff_block = True
        self.backoff_block_completed = False
        backoff_interval = self.backoff_timer.get_interval()

        logger.info('[%s] backing off for %0.2f seconds (%d connections)',
                    self.name, backoff_interval, len(self.conns))
        for c in self.conns.values():
            self._send_rdy(c, 0)

        self.io_loop.call_later(backoff_interval, self._finish_backoff_block)

    def _rdy_retry(self, conn, value):
        conn.rdy_timeout = None
        self._send_rdy(conn, value)

    def _send_rdy(self, conn, value):
        if conn.rdy_timeout:
            self.io_loop.remove_timeout(conn.rdy_timeout)
            conn.rdy_timeout = None

        if value and (self.disabled() or self.max_in_flight == 0):
            logger.info('[%s:%s] disabled, delaying RDY state change', conn.id, self.name)
            rdy_retry_callback = functools.partial(self._rdy_retry, conn, value)
            conn.rdy_timeout = self.io_loop.call_later(15, rdy_retry_callback)
            return

        if value > conn.max_rdy_count:
            value = conn.max_rdy_count

        new_rdy = max(self.total_rdy - conn.rdy + value, 0)
        if conn.send_rdy(value):
            self.total_rdy = new_rdy

    def connect_to_nsqd(self, host, port):
        """
        Adds a connection to ``nsqd`` at the specified address.

        :param host: the address to connect to
        :param port: the port to connect to
        """
        assert isinstance(host, string_types)
        assert isinstance(port, int)

        conn = AsyncConn(host, port, **self.conn_kwargs)
        conn.on('identify', self._on_connection_identify)
        conn.on('identify_response', self._on_connection_identify_response)
        conn.on('auth', self._on_connection_auth)
        conn.on('auth_response', self._on_connection_auth_response)
        conn.on('error', self._on_connection_error)
        conn.on('close', self._on_connection_close)
        conn.on('ready', self._on_connection_ready)
        conn.on('message', self._on_message)
        conn.on('heartbeat', self._on_heartbeat)
        conn.on('backoff', functools.partial(self._on_backoff_resume, success=False))
        conn.on('resume', functools.partial(self._on_backoff_resume, success=True))
        conn.on('continue', functools.partial(self._on_backoff_resume, success=None))

        if conn.id in self.conns:
            return

        # only attempt to re-connect once every 10s per destination
        # this throttles reconnects to failed endpoints
        now = time.time()
        last_connect_attempt = self.connection_attempts.get(conn.id)
        if last_connect_attempt and last_connect_attempt > now - 10:
            return
        self.connection_attempts[conn.id] = now

        logger.info('[%s:%s] connecting to nsqd', conn.id, self.name)
        conn.connect()

        return conn

    def _on_connection_ready(self, conn, **kwargs):
        conn.send(protocol.subscribe(self.topic, self.channel))
        # re-check to make sure another connection didn't beat this one to it
        if conn.id in self.conns:
            logger.warning(
                '[%s:%s] connected to NSQ but another matching connection already exists',
                conn.id, self.name)
            conn.close()
            return

        if conn.max_rdy_count < self.max_in_flight:
            logger.warning(
                '[%s:%s] max RDY count %d < reader max in flight %d, truncation possible',
                conn.id, self.name, conn.max_rdy_count, self.max_in_flight)

        self.conns[conn.id] = conn

        conn_max_in_flight = self._connection_max_in_flight()
        for c in self.conns.values():
            if c.rdy > conn_max_in_flight:
                self._send_rdy(c, conn_max_in_flight)

        # we send an initial RDY of 1 up to our configured max_in_flight
        # this resolves two cases:
        #    1. `max_in_flight >= num_conns` ensuring that no connections are ever
        #       *initially* starved since redistribute won't apply
        #    2. `max_in_flight < num_conns` ensuring that we never exceed max_in_flight
        #       and rely on the fact that redistribute will handle balancing RDY across conns
        if not self.backoff_timer.get_interval() or len(self.conns) == 1:
            # only send RDY 1 if we're not in backoff (some other conn
            # should be testing the waters)
            # (but always send it if we're the first)
            self._send_rdy(conn, 1)

    def _on_connection_close(self, conn, **kwargs):
        if conn.id in self.conns:
            del self.conns[conn.id]

        self.total_rdy = max(self.total_rdy - conn.rdy, 0)

        logger.warning('[%s:%s] connection closed', conn.id, self.name)

        if (conn.rdy_timeout or conn.rdy) and \
                (len(self.conns) == self.max_in_flight or self.backoff_timer.get_interval()):
            # we're toggling out of (normal) redistribution cases and this conn
            # had a RDY count...
            #
            # trigger RDY redistribution to make sure this RDY is moved
            # to a new connection
            self.need_rdy_redistributed = True

        if conn.rdy_timeout:
            self.io_loop.remove_timeout(conn.rdy_timeout)
            conn.rdy_timeout = None

        if not self.lookupd_http_addresses:
            # automatically reconnect to nsqd addresses when not using lookupd
            logger.info('[%s:%s] attempting to reconnect in 15s', conn.id, self.name)
            reconnect_callback = functools.partial(self.connect_to_nsqd,
                                                   host=conn.host, port=conn.port)
            self.io_loop.call_later(15, reconnect_callback)

    @tornado.gen.coroutine
    def query_lookupd(self):
        """
        Trigger a query of the configured ``lookupd_http_addresses``.
        """
        endpoint = self.lookupd_http_addresses[self.lookupd_query_index]
        self.lookupd_query_index = (self.lookupd_query_index + 1) % len(self.lookupd_http_addresses)

        # urlsplit() is faulty if scheme not present
        if '://' not in endpoint:
            endpoint = 'http://' + endpoint

        scheme, netloc, path, query, fragment = urlparse.urlsplit(endpoint)

        if not path or path == "/":
            path = "/lookup"

        params = parse_qs(query)
        params['topic'] = self.topic
        query = urlencode(_utf8_params(params), doseq=1)
        lookupd_url = urlparse.urlunsplit((scheme, netloc, path, query, fragment))

        req = tornado.httpclient.HTTPRequest(
            lookupd_url, method='GET',
            headers={'Accept': 'application/vnd.nsq; version=1.0'},
            connect_timeout=self.lookupd_connect_timeout,
            request_timeout=self.lookupd_request_timeout)

        try:
            response = yield self.http_client.fetch(req)
        except Exception as e:
            logger.warning('[%s] lookupd %s query error: %s',
                           self.name, lookupd_url, e)
            return
        try:
            lookup_data = json.loads(response.body.decode("utf8"))
        except ValueError:
            logger.warning('[%s] lookupd %s failed to parse JSON: %r',
                           self.name, lookupd_url, response.body)
            return

        for producer in lookup_data['producers']:
            # TODO: this can be dropped for 1.0
            address = producer.get('broadcast_address', producer.get('address'))
            assert address
            self.connect_to_nsqd(address, producer['tcp_port'])
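
    # Example: for a configured endpoint '127.0.0.1:4161' and topic
    # 'nsq_reader' the coroutine above requests
    #   GET http://127.0.0.1:4161/lookup?topic=nsq_reader
    # and then connects to every producer listed in the JSON response.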

    def set_max_in_flight(self, max_in_flight):
        """Dynamically adjust the reader max_in_flight. Set to 0 to immediately disable a Reader"""
        assert isinstance(max_in_flight, int)
        self.max_in_flight = max_in_flight

        if max_in_flight == 0:
            # set RDY 0 to all connections
            for conn in itervalues(self.conns):
                if conn.rdy > 0:
                    logger.debug('[%s:%s] rdy: %d -> 0', conn.id, self.name, conn.rdy)
                    self._send_rdy(conn, 0)
            self.total_rdy = 0
        else:
            self.need_rdy_redistributed = True
            self._redistribute_rdy_state()
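
    # Example: reader.set_max_in_flight(0) pauses consumption (connections
    # with a non-zero RDY are sent RDY 0); calling it again with a positive
    # value resumes and forces a redistribution pass.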

    def _redistribute_rdy_state(self):
        # We redistribute RDY counts in a few cases:
        #
        # 1. our # of connections exceeds our configured max_in_flight
        # 2. we're in backoff mode (but not in a current backoff block)
        # 3. something out-of-band has set the need_rdy_redistributed flag (connection closed
        #    that was about to get RDY during backoff)
        #
        # At a high level, we're trying to mitigate stalls related to low-volume
        # producers when we're unable (by configuration or backoff) to provide a RDY count
        # of (at least) 1 to all of our connections.
        if not self.conns:
            return

        if self.disabled() or self.backoff_block or self.max_in_flight == 0:
            return

        if len(self.conns) > self.max_in_flight:
            self.need_rdy_redistributed = True
            logger.debug('redistributing RDY state (%d conns > %d max_in_flight)',
                         len(self.conns), self.max_in_flight)

        backoff_interval = self.backoff_timer.get_interval()
        if backoff_interval and len(self.conns) > 1:
            self.need_rdy_redistributed = True
            logger.debug('redistributing RDY state (%d backoff interval and %d conns > 1)',
                         backoff_interval, len(self.conns))

        if self.need_rdy_redistributed:
            self.need_rdy_redistributed = False

            # first set RDY 0 to all connections that have not received a message within
            # a configurable timeframe (low_rdy_idle_timeout).
            for conn_id, conn in iteritems(self.conns):
                last_message_duration = time.time() - conn.last_msg_timestamp
                logger.debug('[%s:%s] rdy: %d (last message received %.02fs)',
                             conn.id, self.name, conn.rdy, last_message_duration)
                if conn.rdy > 0 and last_message_duration > self.low_rdy_idle_timeout:
                    logger.info('[%s:%s] idle connection, giving up RDY count', conn.id, self.name)
                    self._send_rdy(conn, 0)

            conns = self.conns.values()

            in_flight_or_rdy = len([c for c in conns if c.in_flight or c.rdy])
            if backoff_interval:
                available_rdy = max(0, 1 - in_flight_or_rdy)
            else:
                available_rdy = max(0, self.max_in_flight - in_flight_or_rdy)

            # if moving any connections from RDY 0 to non-0 would violate in-flight constraints,
            # set RDY 0 on some connection with msgs in flight so that a later redistribution
            # round can proceed and we don't stay pinned to the same connections.
            #
            # if nothing's in flight, then we have connections with RDY 1 that are still
            # waiting to hit the idle timeout, in which case it's ok to do nothing.
            in_flight = [c for c in conns if c.in_flight]
            if in_flight and not available_rdy:
                conn = random.choice(in_flight)
                logger.info('[%s:%s] too many msgs in flight, giving up RDY count',
                            conn.id, self.name)
                self._send_rdy(conn, 0)

            # randomly walk the list of possible connections and send RDY 1 (up to our
            # calculated "max_in_flight").  We only need to send RDY 1 because in both
            # cases described above your per connection RDY count would never be higher.
            #
            # We also don't attempt to avoid the connections who previously might have had RDY 1
            # because it would be overly complicated and not actually worth it (ie. given enough
            # redistribution rounds it doesn't matter).
            possible_conns = [c for c in conns if not (c.in_flight or c.rdy)]
            while possible_conns and available_rdy:
                available_rdy -= 1
                conn = possible_conns.pop(random.randrange(len(possible_conns)))
                logger.info('[%s:%s] redistributing RDY', conn.id, self.name)
                self._send_rdy(conn, 1)

            # for tests
            return conn

    #
    # subclass overwriteable
    #

    def process_message(self, message):
        """
        Called when a message is received in order to execute the configured ``message_handler``

        This is useful to subclass and override if you want to change how your
        message handlers are called.

        :param message: the :class:`nsq.Message` received
        """
        return self.message_handler(message)

    def giving_up(self, message):
        """
        Called when a message has been received where ``msg.attempts > max_tries``

        This is useful to subclass and override to perform a task (such as writing to disk, etc.)

        :param message: the :class:`nsq.Message` received
        """
        logger.warning('[%s] giving up on message %s after %d tries (max:%d) %r',
                       self.name, message.id, message.attempts, self.max_tries, message.body)

    def _on_connection_identify_response(self, conn, data, **kwargs):
        if not hasattr(self, '_disabled_notice'):
            self._disabled_notice = True

            def semver(v):
                def cast(x):
                    try:
                        return int(x)
                    except Exception:
                        return x
                return [cast(x) for x in v.replace('-', '.').split('.')]

            if self.disabled.__code__ != Reader.disabled.__code__ and \
               semver(data['version']) >= semver('0.3'):
                warnings.warn('disabled() is deprecated and will be removed in a future release, '
                              'use set_max_in_flight(0) instead', DeprecationWarning)
        return super(Reader, self)._on_connection_identify_response(conn, data, **kwargs)

    @classmethod
    def disabled(cls):
        """
        Called as part of RDY handling to identify whether this Reader has been disabled

        This is useful to subclass and override to examine a file on disk or a key in cache
        to identify if this reader should pause execution (during a deploy, etc.).

        Note: deprecated. Use set_max_in_flight(0)
        """
        return False

    def validate_message(self, message):
        return True

    def preprocess_message(self, message):
        return message
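
# A minimal sketch of the coroutine-handler style mentioned in the Reader
# docstring: the handler is a Tornado coroutine, async handling is enabled
# automatically, and the resolved True/False value finishes or requeues the
# message. The addresses, topic and channel below are illustrative
# assumptions.
import nsq
import tornado.gen

@tornado.gen.coroutine
def handler(message):
    # simulate non-blocking work on the IOLoop
    yield tornado.gen.sleep(0.1)
    raise tornado.gen.Return(True)   # True -> finish, False -> requeue

r = nsq.Reader(message_handler=handler,
               lookupd_http_addresses=['http://127.0.0.1:4161'],
               topic='nsq_reader', channel='coro', max_in_flight=5)
nsq.run()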
Beispiel #56
0
class ApsCapture(object):
    def __init__(self, capture_interface, control_socket,
                 mkrecv_config_filename, mkrecv_cpu_set, apsuse_cpu_set,
                 sensor_prefix, dada_key):
        log.info("Building ApsCapture instance with parameters: ({})".format(
            capture_interface, control_socket, mkrecv_config_filename,
            mkrecv_cpu_set, apsuse_cpu_set, sensor_prefix, dada_key))
        self._capture_interface = capture_interface
        self._control_socket = control_socket
        self._mkrecv_config_filename = mkrecv_config_filename
        self._mkrecv_cpu_set = mkrecv_cpu_set
        self._apsuse_cpu_set = apsuse_cpu_set
        self._sensor_prefix = sensor_prefix
        self._dada_input_key = dada_key
        self._mkrecv_proc = None
        self._apsuse_proc = None
        self._dada_db_proc = None
        self._ingress_buffer_monitor = None
        self._internal_beam_mapping = {}
        self._sensors = []
        self._capturing = False
        self.ioloop = IOLoop.current()
        self.setup_sensors()

    def add_sensor(self, sensor):
        sensor.name = "{}-{}".format(self._sensor_prefix, sensor.name)
        self._sensors.append(sensor)

    def setup_sensors(self):
        self._config_sensor = Sensor.string(
            "configuration",
            description="The current configuration of the capture instance",
            default="",
            initial_status=Sensor.UNKNOWN)
        self.add_sensor(self._config_sensor)

        self._mkrecv_header_sensor = Sensor.string(
            "mkrecv-capture-header",
            description=("The MKRECV/DADA header used for configuring "
                         "capture with MKRECV"),
            default="",
            initial_status=Sensor.UNKNOWN)
        self.add_sensor(self._mkrecv_header_sensor)

        self._apsuse_args_sensor = Sensor.string(
            "apsuse-arguments",
            description="The command line arguments used to invoke apsuse",
            default="",
            initial_status=Sensor.UNKNOWN)
        self.add_sensor(self._apsuse_args_sensor)

        self._mkrecv_heap_loss = Sensor.float(
            "fbf-heap-loss",
            description=("The percentage of FBFUSE heaps lost "
                         "(within MKRECV statistics window)"),
            default=0.0,
            initial_status=Sensor.UNKNOWN,
            unit="%")
        self.add_sensor(self._mkrecv_heap_loss)

        self._ingress_buffer_percentage = Sensor.float(
            "ingress-buffer-fill-level",
            description=("The percentage fill level for the capture"
                         "buffer between MKRECV and APSUSE"),
            default=0.0,
            initial_status=Sensor.UNKNOWN,
            unit="%")
        self.add_sensor(self._ingress_buffer_percentage)

    @coroutine
    def _start_db(self, key, block_size, nblocks, timeout=100.0):
        log.debug(("Building DADA buffer: key={}, block_size={}, "
                   "nblocks={}").format(key, block_size, nblocks))
        cmdline = map(str, [
            "dada_db", "-k", key, "-b", block_size, "-n", nblocks, "-l", "-p",
            "-w"
        ])
        self._dada_db_proc = Popen(cmdline,
                                   stdout=PIPE,
                                   stderr=PIPE,
                                   shell=False,
                                   close_fds=True)
        start = time.time()
        while psutil.virtual_memory().cached < block_size * nblocks:
            log.info("Cached: {} bytes, require {} bytes".format(
                psutil.virtual_memory().cached, block_size * nblocks))
            yield sleep(1.0)
            if time.time() - start > timeout:
                raise Exception(
                    "Caching of DADA buffer took longer than {} seconds".
                    format(timeout))
        log.info("Took {} seconds to allocate {}x{} GB DADA buffer".format(
            time.time() - start, block_size / 1e9, nblocks))
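
    # Example (key and sizes are illustrative): for key "dada", a 1 GB block
    # and 8 blocks the coroutine above runs
    #   dada_db -k dada -b 1000000000 -n 8 -l -p -w
    # and waits until psutil reports at least block_size * nblocks bytes in
    # the page cache before returning.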

    def _stop_db(self):
        log.debug("Destroying DADA buffer")
        if self._dada_db_proc is not None:
            self._dada_db_proc.terminate()
            self._dada_db_proc.wait()
        self._dada_db_proc = None

    @coroutine
    def capture_start(self, config):
        log.info("Preparing apsuse capture instance")
        log.info("Config: {}".format(config))
        nbeams = len(config['beam-ids'])
        npartitions = config['nchans'] / config['nchans-per-heap']
        # desired beams here is a list of beam IDs, e.g. [1,2,3,4,5]
        heap_group_size = config['heap-size'] * nbeams * npartitions

        #heap_group_duration = (heap_group_size / config['nchans-per-heap']) * config['sampling-interval']
        #optimal_heap_groups = int(OPTIMAL_BLOCK_LENGTH / heap_group_duration)
        #if (optimal_heap_groups * heap_group_size) > MAX_DADA_BLOCK_SIZE:
        ngroups_data = int(MAX_DADA_BLOCK_SIZE / heap_group_size)
        #else:
        #    ngroups_data = optimal_heap_groups

        # Move to power of 2 heap groups (not necessary, but helpful)
        ngroups_data = 2**((ngroups_data - 1).bit_length())

        # Make sure at least 8 groups used
        ngroups_data = max(ngroups_data, 8)
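
        # Example of the two adjustments above: ngroups_data of 5, 6 or 7
        # rounds up to 8, 9 rounds up to 16, and anything that rounds below
        # 8 is clamped to the minimum of 8.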

        # Make DADA buffer and start watchers
        log.info("Creating capture buffer")
        capture_block_size = ngroups_data * heap_group_size
        if (capture_block_size * OPTIMAL_CAPTURE_BLOCKS >
                AVAILABLE_CAPTURE_MEMORY):
            capture_block_count = int(AVAILABLE_CAPTURE_MEMORY /
                                      capture_block_size)
            if capture_block_count < 3:
                raise Exception("Cannot allocate more than 2 capture blocks")
        else:
            capture_block_count = OPTIMAL_CAPTURE_BLOCKS

        log.debug("Creating dada buffer for input with key '{}'".format(
            "%s" % self._dada_input_key))
        yield self._start_db(self._dada_input_key, capture_block_size,
                             capture_block_count)
        log.info("Capture buffer ready")
        self._config_sensor.set_value(json.dumps(config))
        idx = 0
        for beam in config['beam-ids']:
            self._internal_beam_mapping[beam] = idx
            idx += 1

        # Start APSUSE processing code
        apsuse_cmdline = [
            "taskset", "-c", self._apsuse_cpu_set, "apsuse", "--input_key",
            self._dada_input_key, "--ngroups", ngroups_data, "--nbeams",
            nbeams, "--nchannels", config['nchans-per-heap'], "--nsamples",
            config['heap-size'] / config['nchans-per-heap'], "--nfreq",
            npartitions, "--size",
            int(config['filesize']), "--socket", self._control_socket, "--dir",
            config["base-output-dir"], "--log_level", "info"
        ]
        log.info("Starting APSUSE")
        log.debug(" ".join(map(str, apsuse_cmdline)))
        self._apsuse_proc = ManagedProcess(apsuse_cmdline,
                                           stdout_handler=log.debug,
                                           stderr_handler=log.error)
        self._apsuse_args_sensor.set_value(" ".join(map(str, apsuse_cmdline)))
        yield sleep(5)

        def make_beam_list(indices):
            spec = ""
            for a, b in itertools.groupby(enumerate(indices),
                                          lambda pair: pair[1] - pair[0]):
                if spec != "":
                    spec += ","
                b = list(b)
                p, q = b[0][1], b[-1][1]
                if p == q:
                    spec += "{}".format(p)
                else:
                    spec += "{}:{}".format(p, q + 1)
            return spec
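
        # Example: stream indices [0, 1, 2, 5, 7, 8] collapse to the spec
        # string "0:3,5,7:9" (ranges are half-open, isolated values stay bare).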

        # Start MKRECV capture code
        mkrecv_config = {
            'dada_mode': 4,
            'dada_key': self._dada_input_key,
            'bandwidth': config['bandwidth'],
            'centre_frequency': config['centre-frequency'],
            'nchannels': config["nchans"],
            'sampling_interval': config['sampling-interval'],
            'sync_epoch': config['sync-epoch'],
            'sample_clock': config['sample-clock'],
            'mcast_sources': ",".join(config['mcast-groups']),
            'nthreads': len(config['mcast-groups']) + 1,
            'mcast_port': str(config['mcast-port']),
            'interface': self._capture_interface,
            'timestamp_step': config['idx1-step'],
            'timestamp_modulus': 1,
            'beam_ids_csv': make_beam_list(config['stream-indices']),
            'freq_ids_csv': "0:{}:{}".format(config['nchans'],
                                             config['nchans-per-heap']),
            'heap_size': config['heap-size']
        }
        mkrecv_header = make_mkrecv_header(
            mkrecv_config, outfile=self._mkrecv_config_filename)
        log.info("Determined MKRECV configuration:\n{}".format(mkrecv_header))
        self._mkrecv_header_sensor.set_value(mkrecv_header)

        def update_heap_loss_sensor(curr, total, avg, window):
            self._mkrecv_heap_loss.set_value(100.0 - avg)

        mkrecv_sensor_updater = MkrecvStdoutHandler(
            callback=update_heap_loss_sensor)

        def mkrecv_aggregated_output_handler(line):
            log.debug(line)
            mkrecv_sensor_updater(line)

        log.info("Starting MKRECV")
        self._mkrecv_proc = ManagedProcess(
            [
                "taskset", "-c", self._mkrecv_cpu_set, "mkrecv_rnt",
                "--header", self._mkrecv_config_filename, "--quiet"
            ],
            stdout_handler=mkrecv_aggregated_output_handler,
            stderr_handler=log.error)
        yield sleep(5)

        def exit_check_callback():
            if not self._mkrecv_proc.is_alive():
                log.error("mkrecv_nt exited unexpectedly")
                self.ioloop.add_callback(self.capture_stop)
            elif not self._apsuse_proc.is_alive():
                log.error("apsuse pipeline exited unexpectedly")
                self.ioloop.add_callback(self.capture_stop)
            self._capture_monitor.stop()

        self._capture_monitor = PeriodicCallback(exit_check_callback, 1000)
        self._capture_monitor.start()
        self._ingress_buffer_monitor = DbMonitor(
            self._dada_input_key,
            callback=lambda params: self._ingress_buffer_percentage.set_value(
                params["fraction-full"]))
        self._ingress_buffer_monitor.start()
        self._capturing = True
        log.info("Successfully started capture pipeline")

    def target_start(self, beam_info, output_dir):
        # Send target information to apsuse pipeline
        # and trigger file writing

        # First build message containing beam information
        # in JSON form:
        #
        # {
        #   "command": "start",
        #   "directory": "<output_dir>",
        #   "beam_parameters": [
        #     {"idx": 0, "name": "cfbf00000", "source": "PSRJ1823+3410", "ra": "00:00:00.00", "dec": "00:00:00.00"},
        #     {"idx": 1, "name": "cfbf00002", "source": "SBGS0000", "ra": "00:00:00.00", "dec": "00:00:00.00"},
        #     {"idx": 2, "name": "cfbf00006", "source": "SBGS0000", "ra": "00:00:00.00", "dec": "00:00:00.00"},
        #     {"idx": 3, "name": "cfbf00008", "source": "SBGS0000", "ra": "00:00:00.00", "dec": "00:00:00.00"},
        #     {"idx": 4, "name": "cfbf00010", "source": "SBGS0000", "ra": "00:00:00.00", "dec": "00:00:00.00"}
        #   ]
        # }
        #
        # Here the "idx" parameter refers to the internal index of the beam, e.g. if the
        # apsuse executable is handling 6 beams these are numbered 0-5 regardless of their
        # global index. It is thus necessary to track the mapping between internal and
        # external indices for these beams.
        #
        log.info("Target start on capture instance")
        beam_params = []
        message_dict = {
            "command": "start",
            "directory": output_dir,
            "beam_parameters": beam_params
        }
        log.info("Parsing beam information")
        for beam, target_str in beam_info.items():
            if beam in self._internal_beam_mapping:
                idx = self._internal_beam_mapping[beam]
                target = Target(target_str)
                ra, dec = map(str, target.radec())
                log.info(
                    "IDX: {}, name: {}, ra: {}, dec: {}, source: {}".format(
                        idx, beam, ra, dec, target.name))
                beam_params.append({
                    "idx": idx,
                    "name": beam,
                    "source": target.name,
                    "ra": ra,
                    "dec": dec
                })
        log.debug("Connecting to apsuse instance via socket")
        client = UDSClient(self._control_socket)
        log.debug("Sending message: {}".format(json.dumps(message_dict)))
        client.send(json.dumps(message_dict))
        response_str = client.recv(timeout=3)
        try:
            response = json.loads(response_str)["response"]
        except Exception:
            log.exception(
                "Unable to parse JSON returned from apsuse application")
        else:
            log.debug("Response: {}".format(response_str))
            if response != "success":
                raise Exception("Failed to start APSUSE recording")
        finally:
            client.close()
        log.debug("Closed socket connection")

    def target_stop(self):
        # Trigger end of file writing
        # First build JSON message to trigger end.
        # Message has the form:
        # {
        #     "command": "stop"
        # }
        log.info("Target stop request on capture instance")
        message = {"command": "stop"}
        log.debug("Connecting to apsuse instance via socket")
        client = UDSClient(self._control_socket)
        log.debug("Sending message: {}".format(json.dumps(message)))
        client.send(json.dumps(message))
        response_str = client.recv(timeout=3)
        try:
            response = json.loads(response_str)["response"]
        except Exception:
            log.exception(
                "Unable to parse JSON returned from apsuse application")
        else:
            log.debug("Response: {}".format(response_str))
            if response != "success":
                raise Exception("Failed to stop APSUSE recording")
        finally:
            client.close()
            log.debug("Closed socket connection")

    @coroutine
    def capture_stop(self):
        log.info("Capture stop request on capture instance")
        self._capturing = False
        self._internal_beam_mapping = {}
        log.info("Stopping capture monitors")
        self._capture_monitor.stop()
        self._ingress_buffer_monitor.stop()
        log.info("Stopping MKRECV instance")
        self._mkrecv_proc.terminate()
        log.info("Stopping PSRDADA_CPP instance")
        self._apsuse_proc.terminate()
        log.info("Destroying DADA buffers")
        self._stop_db()
Beispiel #57
0
class UpdateManager:
    def __init__(self, config):
        self.server = config.get_server()
        self.config = config
        self.config.read_supplemental_config(SUPPLEMENTAL_CFG_PATH)
        self.repo_debug = config.getboolean('enable_repo_debug', False)
        auto_refresh_enabled = config.getboolean('enable_auto_refresh', False)
        self.distro = config.get('distro', "debian").lower()
        if self.distro not in SUPPORTED_DISTROS:
            raise config.error(f"Unsupported distro: {self.distro}")
        if self.repo_debug:
            logging.warn("UPDATE MANAGER: REPO DEBUG ENABLED")
        env = sys.executable
        mooncfg = self.config[f"update_manager static {self.distro} moonraker"]
        self.updaters = {
            "system": PackageUpdater(self),
            "moonraker": GitUpdater(self, mooncfg, MOONRAKER_PATH, env)
        }
        self.current_update = None
        # TODO: Check for client config in [update_manager].  This is
        # deprecated and will be removed.
        client_repo = config.get("client_repo", None)
        if client_repo is not None:
            client_path = config.get("client_path")
            name = client_repo.split("/")[-1]
            self.updaters[name] = WebUpdater(self, {
                'repo': client_repo,
                'path': client_path
            })
        client_sections = self.config.get_prefix_sections(
            "update_manager client")
        for section in client_sections:
            cfg = self.config[section]
            name = section.split()[-1]
            if name in self.updaters:
                raise config.error("Client repo named %s already added" %
                                   (name, ))
            client_type = cfg.get("type")
            if client_type == "git_repo":
                self.updaters[name] = GitUpdater(self, cfg)
            elif client_type == "web":
                self.updaters[name] = WebUpdater(self, cfg)
            else:
                raise config.error("Invalid type '%s' for section [%s]" %
                                   (client_type, section))

        # GitHub API Rate Limit Tracking
        self.gh_rate_limit = None
        self.gh_limit_remaining = None
        self.gh_limit_reset_time = None
        self.gh_init_evt = Event()
        self.cmd_request_lock = Lock()
        self.is_refreshing = False

        # Auto Status Refresh
        self.last_auto_update_time = 0
        self.refresh_cb = None
        if auto_refresh_enabled:
            self.refresh_cb = PeriodicCallback(self._handle_auto_refresh,
                                               UPDATE_REFRESH_INTERVAL_MS)
            self.refresh_cb.start()

        AsyncHTTPClient.configure(None, defaults=dict(user_agent="Moonraker"))
        self.http_client = AsyncHTTPClient()

        self.server.register_endpoint("/machine/update/moonraker", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/klipper", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/system", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/client", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/status", ["GET"],
                                      self._handle_status_request)
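
        # Illustration (host and port are assumptions; Moonraker's API
        # typically listens on 7125): clients poll update state with
        #   GET  http://<host>:7125/machine/update/status
        # and trigger an update of one component with e.g.
        #   POST http://<host>:7125/machine/update/moonraker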

        # Register Ready Event
        self.server.register_event_handler("server:klippy_identified",
                                           self._set_klipper_repo)
        # Initialize GitHub API Rate Limits and configured updaters
        IOLoop.current().spawn_callback(self._initalize_updaters,
                                        list(self.updaters.values()))

    async def _initalize_updaters(self, initial_updaters):
        self.is_refreshing = True
        await self._init_api_rate_limit()
        for updater in initial_updaters:
            if isinstance(updater, PackageUpdater):
                ret = updater.refresh(False)
            else:
                ret = updater.refresh()
            if asyncio.iscoroutine(ret):
                await ret
        self.is_refreshing = False

    async def _set_klipper_repo(self):
        kinfo = self.server.get_klippy_info()
        if not kinfo:
            logging.info("No valid klippy info received")
            return
        kpath = kinfo['klipper_path']
        env = kinfo['python_path']
        kupdater = self.updaters.get('klipper', None)
        if kupdater is not None and kupdater.repo_path == kpath and \
                kupdater.env == env:
            # Current Klipper Updater is valid
            return
        kcfg = self.config[f"update_manager static {self.distro} klipper"]
        self.updaters['klipper'] = GitUpdater(self, kcfg, kpath, env)
        await self.updaters['klipper'].refresh()

    async def _check_klippy_printing(self):
        klippy_apis = self.server.lookup_plugin('klippy_apis')
        result = await klippy_apis.query_objects({'print_stats': None},
                                                 default={})
        pstate = result.get('print_stats', {}).get('state', "")
        return pstate.lower() == "printing"

    async def _handle_auto_refresh(self):
        if await self._check_klippy_printing():
            # Don't Refresh during a print
            logging.info("Klippy is printing, auto refresh aborted")
            return
        cur_time = time.time()
        cur_hour = time.localtime(cur_time).tm_hour
        time_diff = cur_time - self.last_auto_update_time
        # Update packages if it has been more than 12 hours
        # and the local time is between 12AM and 5AM
        if time_diff < MIN_REFRESH_TIME or cur_hour >= MAX_PKG_UPDATE_HOUR:
            # Not within the update time window
            return
        self.last_auto_update_time = cur_time
        vinfo = {}
        need_refresh_all = not self.is_refreshing
        async with self.cmd_request_lock:
            self.is_refreshing = True
            try:
                for name, updater in list(self.updaters.items()):
                    if need_refresh_all:
                        ret = updater.refresh()
                        if asyncio.iscoroutine(ret):
                            await ret
                    if hasattr(updater, "get_update_status"):
                        vinfo[name] = updater.get_update_status()
            except Exception:
                logging.exception("Unable to Refresh Status")
                return
            finally:
                self.is_refreshing = False
        uinfo = {
            'version_info': vinfo,
            'github_rate_limit': self.gh_rate_limit,
            'github_requests_remaining': self.gh_limit_remaining,
            'github_limit_reset_time': self.gh_limit_reset_time,
            'busy': self.current_update is not None
        }
        self.server.send_event("update_manager:update_refreshed", uinfo)

    async def _handle_update_request(self, web_request):
        if await self._check_klippy_printing():
            raise self.server.error("Update Refused: Klippy is printing")
        app = web_request.get_endpoint().split("/")[-1]
        if app == "client":
            app = web_request.get('name')
        inc_deps = web_request.get_boolean('include_deps', False)
        if self.current_update is not None and \
                self.current_update[0] == app:
            return f"Object {app} is currently being updated"
        updater = self.updaters.get(app, None)
        if updater is None:
            raise self.server.error(f"Updater {app} not available")
        async with self.cmd_request_lock:
            self.current_update = (app, id(web_request))
            try:
                await updater.update(inc_deps)
            except Exception as e:
                self.notify_update_response(f"Error updating {app}")
                self.notify_update_response(str(e), is_complete=True)
                raise
            finally:
                self.current_update = None
        return "ok"

    async def _handle_status_request(self, web_request):
        check_refresh = web_request.get_boolean('refresh', False)
        # Don't refresh if a print is currently in progress or
        # if an update is in progress.  Just return the current
        # state
        if self.current_update is not None or \
                await self._check_klippy_printing():
            check_refresh = False
        need_refresh = False
        if check_refresh:
            # If there is an outstanding request processing a
            # refresh, we don't need to do it again.
            need_refresh = not self.is_refreshing
            await self.cmd_request_lock.acquire()
            self.is_refreshing = True
        vinfo = {}
        try:
            for name, updater in list(self.updaters.items()):
                await updater.check_initialized(120.)
                if need_refresh:
                    ret = updater.refresh()
                    if asyncio.iscoroutine(ret):
                        await ret
                if hasattr(updater, "get_update_status"):
                    vinfo[name] = updater.get_update_status()
        except Exception:
            raise
        finally:
            if check_refresh:
                self.is_refreshing = False
                self.cmd_request_lock.release()
        return {
            'version_info': vinfo,
            'github_rate_limit': self.gh_rate_limit,
            'github_requests_remaining': self.gh_limit_remaining,
            'github_limit_reset_time': self.gh_limit_reset_time,
            'busy': self.current_update is not None
        }

    async def execute_cmd(self, cmd, timeout=10., notify=False, retries=1):
        shell_command = self.server.lookup_plugin('shell_command')
        cb = self.notify_update_response if notify else None
        scmd = shell_command.build_shell_command(cmd, callback=cb)
        while retries:
            if await scmd.run(timeout=timeout, verbose=notify):
                break
            retries -= 1
        if not retries:
            raise self.server.error("Shell Command Error")

    async def execute_cmd_with_response(self, cmd, timeout=10.):
        shell_command = self.server.lookup_plugin('shell_command')
        scmd = shell_command.build_shell_command(cmd, None)
        result = await scmd.run_with_response(timeout, retries=5)
        if result is None:
            raise self.server.error(f"Error Running Command: {cmd}")
        return result

    async def _init_api_rate_limit(self):
        url = "https://api.github.com/rate_limit"
        while True:
            try:
                resp = await self.github_api_request(url, is_init=True)
                core = resp['resources']['core']
                self.gh_rate_limit = core['limit']
                self.gh_limit_remaining = core['remaining']
                self.gh_limit_reset_time = core['reset']
            except Exception:
                logging.exception("Error Initializing GitHub API Rate Limit")
                await tornado.gen.sleep(30.)
            else:
                reset_time = time.ctime(self.gh_limit_reset_time)
                logging.info(
                    "GitHub API Rate Limit Initialized\n"
                    f"Rate Limit: {self.gh_rate_limit}\n"
                    f"Rate Limit Remaining: {self.gh_limit_remaining}\n"
                    f"Rate Limit Reset Time: {reset_time}, "
                    f"Seconds Since Epoch: {self.gh_limit_reset_time}")
                break
        self.gh_init_evt.set()

    async def github_api_request(self, url, etag=None, is_init=False):
        if not is_init:
            timeout = time.time() + 30.
            try:
                await self.gh_init_evt.wait(timeout)
            except Exception:
                raise self.server.error("Timeout while waiting for GitHub "
                                        "API Rate Limit initialization")
        if self.gh_limit_remaining == 0:
            curtime = time.time()
            if curtime < self.gh_limit_reset_time:
                raise self.server.error(
                    f"GitHub Rate Limit Reached\nRequest: {url}\n"
                    f"Limit Reset Time: {time.ctime(self.gh_limit_remaining)}")
        headers = {"Accept": "application/vnd.github.v3+json"}
        if etag is not None:
            headers['If-None-Match'] = etag
        retries = 5
        while retries:
            try:
                timeout = time.time() + 10.
                fut = self.http_client.fetch(url,
                                             headers=headers,
                                             connect_timeout=5.,
                                             request_timeout=5.,
                                             raise_error=False)
                resp = await tornado.gen.with_timeout(timeout, fut)
            except Exception:
                retries -= 1
                msg = f"Error Processing GitHub API request: {url}"
                if not retries:
                    raise self.server.error(msg)
                logging.exception(msg)
                await tornado.gen.sleep(1.)
                continue
            etag = resp.headers.get('etag', None)
            if etag is not None:
                if etag[:2] == "W/":
                    etag = etag[2:]
            logging.info("GitHub API Request Processed\n"
                         f"URL: {url}\n"
                         f"Response Code: {resp.code}\n"
                         f"Response Reason: {resp.reason}\n"
                         f"ETag: {etag}")
            if resp.code == 403:
                raise self.server.error(
                    f"Forbidden GitHub Request: {resp.reason}")
            elif resp.code == 304:
                logging.info(f"Github Request not Modified: {url}")
                return None
            if resp.code != 200:
                retries -= 1
                if not retries:
                    raise self.server.error(
                        f"Github Request failed: {resp.code} {resp.reason}")
                logging.info(
                    f"Github request error, {retries} retries remaining")
                await tornado.gen.sleep(1.)
                continue
            # Update rate limit on return success
            if 'X-Ratelimit-Limit' in resp.headers and not is_init:
                self.gh_rate_limit = int(resp.headers['X-Ratelimit-Limit'])
                self.gh_limit_remaining = int(
                    resp.headers['X-Ratelimit-Remaining'])
                self.gh_limit_reset_time = float(
                    resp.headers['X-Ratelimit-Reset'])
            decoded = json.loads(resp.body)
            decoded['etag'] = etag
            return decoded

    async def http_download_request(self, url):
        retries = 5
        while retries:
            try:
                timeout = time.time() + 130.
                fut = self.http_client.fetch(
                    url,
                    headers={"Accept": "application/zip"},
                    connect_timeout=5.,
                    request_timeout=120.)
                resp = await tornado.gen.with_timeout(timeout, fut)
            except Exception:
                retries -= 1
                logging.exception("Error Processing Download")
                if not retries:
                    raise
                await tornado.gen.sleep(1.)
                continue
            return resp.body

    def notify_update_response(self, resp, is_complete=False):
        resp = resp.strip()
        if isinstance(resp, bytes):
            resp = resp.decode()
        notification = {
            'message': resp,
            'application': None,
            'proc_id': None,
            'complete': is_complete
        }
        if self.current_update is not None:
            notification['application'] = self.current_update[0]
            notification['proc_id'] = self.current_update[1]
        self.server.send_event("update_manager:update_response", notification)

    def close(self):
        self.http_client.close()
        if self.refresh_cb is not None:
            self.refresh_cb.stop()
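
The github_api_request coroutine above pairs ETag-based conditional requests with rate-limit bookkeeping: a cached ETag is sent via If-None-Match, a 304 reply means the cached body can be reused, and the X-Ratelimit-* headers are recorded on success. The snippet below is a minimal standalone sketch of that pattern using Tornado's AsyncHTTPClient; the cache dictionary and the cached_github_get name are illustrative helpers, not part of the class above, and the sketch makes a single attempt with no retries or rate-limit tracking.

# Minimal sketch of ETag-based conditional GitHub requests (illustrative only;
# single attempt, no retries, no rate-limit tracking).
import json

from tornado.httpclient import AsyncHTTPClient

_etag_cache = {}  # url -> (etag, decoded_body); hypothetical helper cache

async def cached_github_get(url):
    headers = {"Accept": "application/vnd.github.v3+json"}
    etag, cached = _etag_cache.get(url, (None, None))
    if etag is not None:
        headers["If-None-Match"] = etag
    resp = await AsyncHTTPClient().fetch(url, headers=headers,
                                         connect_timeout=5.,
                                         request_timeout=5.,
                                         raise_error=False)
    if resp.code == 304 and cached is not None:
        # Not modified: reuse the previously decoded body
        return cached
    if resp.code != 200:
        raise RuntimeError(f"GitHub request failed: {resp.code} {resp.reason}")
    decoded = json.loads(resp.body)
    new_etag = resp.headers.get("Etag")
    if new_etag is not None:
        _etag_cache[url] = (new_etag, decoded)
    return decoded

GitHub serves 304 replies without charging the primary rate limit, which is what makes the ETag round-trip worthwhile and presumably why the method above accepts an etag argument.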
Beispiel #58
0
class BlockingPool(object):
    """A connection pool that manages blocking PostgreSQL connections
    and cursors.

    :param min_conn: The minimum number of connections created when the
                     connection pool is initialized.
    :param max_conn: The maximum number of connections the pool may hold. If
                     this limit is exceeded a ``PoolError`` exception is
                     raised.
    :param cleanup_timeout: Time in seconds between pool cleanups. Connections
                            will be closed until there are ``min_conn`` left.
    :param database: The database name
    :param user: User name used to authenticate
    :param password: Password used to authenticate
    :param connection_factory: A callable taking a ``dsn`` argument that
                               returns connections; use it to supply a custom
                               connection class or factory.
    """
    def __init__(self,
                 min_conn=1,
                 max_conn=20,
                 cleanup_timeout=10,
                 *args,
                 **kwargs):
        self.min_conn = min_conn
        self.max_conn = max_conn
        self.closed = False

        self._args = args
        self._kwargs = kwargs

        self._pool = []

        for i in range(self.min_conn):
            self._new_conn()

        # Create a periodic callback that tries to close inactive connections
        if cleanup_timeout > 0:
            self._cleaner = PeriodicCallback(self._clean_pool,
                                             cleanup_timeout * 1000)
            self._cleaner.start()

    def _new_conn(self):
        """Create a new connection.
        """
        if len(self._pool) > self.max_conn:
            raise PoolError('connection pool exhausted')
        conn = psycopg2.connect(*self._args, **self._kwargs)
        self._pool.append(conn)

        return conn

    def _get_free_conn(self):
        """Look for a free connection and return it.

        `None` is returned when no free connection can be found.
        """
        if self.closed:
            raise PoolError('connection pool is closed')
        for conn in self._pool:
            if conn.status == STATUS_READY:
                return conn
        return None

    def get_connection(self):
        """Get a connection from the pool.

        If there's no free connection available, a new connection will be created.
        """
        connection = self._get_free_conn()
        if not connection:
            connection = self._new_conn()

        return connection

    def _clean_pool(self):
        """Close a number of inactive connections when the number of connections
        in the pool exceeds the number in `min_conn`.
        """
        if self.closed:
            raise PoolError('connection pool is closed')
        if len(self._pool) > self.min_conn:
            conns = len(self._pool) - self.min_conn
            for conn in self._pool[:]:
                if conn.status == STATUS_READY:
                    conn.close()
                    conns -= 1
                    self._pool.remove(conn)
                    if not conns:
                        break

    def close(self):
        """Close all open connections in the pool.
        """
        if self.closed:
            raise PoolError('connection pool is closed')
        for conn in self._pool:
            if not conn.closed:
                conn.close()
        self._cleaner.stop()
        self._pool = []
        self.closed = True
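
A brief usage sketch for the pool above. The connection keyword arguments are placeholders passed straight through to psycopg2.connect, and the periodic cleanup only fires while a Tornado IOLoop is running.

# Illustrative use of BlockingPool; connection parameters are placeholders.
pool = BlockingPool(min_conn=1,
                    max_conn=10,
                    cleanup_timeout=10,
                    host="localhost",
                    database="example_db",
                    user="example_user",
                    password="secret")

conn = pool.get_connection()
cur = conn.cursor()
cur.execute("SELECT 1")
print(cur.fetchone())
cur.close()
pool.close()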
Beispiel #59
0
class AsyncPool(object):
    """A connection pool that manages asynchronous PostgreSQL connections
    and cursors.

    :param min_conn: The minimum number of connections created when the
                     connection pool is initialized.
    :param max_conn: The maximum number of connections the pool may hold. If
                     this limit is exceeded a ``PoolError`` exception is
                     raised.
    :param cleanup_timeout: Time in seconds between pool cleanups. Connections
                            will be closed until there are ``min_conn`` left.
    :param ioloop: An instance of Tornado's IOLoop.
    :param host: The database host address (defaults to a UNIX socket if not
                 provided)
    :param port: The database host port (defaults to 5432 if not provided)
    :param database: The database name
    :param user: User name used to authenticate
    :param password: Password used to authenticate
    :param connection_factory: A callable taking a ``dsn`` argument that
                               returns connections; use it to supply a custom
                               connection class or factory.
    """
    def __init__(self,
                 min_conn=1,
                 max_conn=20,
                 cleanup_timeout=10,
                 ioloop=None,
                 *args,
                 **kwargs):
        self.min_conn = min_conn
        self.max_conn = max_conn
        self.closed = False
        self._ioloop = ioloop or IOLoop.instance()
        self._args = args
        self._kwargs = kwargs
        self._last_reconnect = 0

        self._pool = []

        for i in range(self.min_conn):
            self._new_conn()
        self._last_reconnect = time.time()

        # Create a periodic callback that tries to close inactive connections
        if cleanup_timeout > 0:
            self._cleaner = PeriodicCallback(self._clean_pool,
                                             cleanup_timeout * 1000)
            self._cleaner.start()

    def _new_conn(self, callback=None, callback_args=[]):
        """Create a new connection.

        :param callback: Optional callable invoked once the connection is open.
        :param callback_args: Arguments for the callback; the new connection is
                              appended to them before the call.
        """
        if len(self._pool) > self.max_conn:
            self._clean_pool()
        if len(self._pool) > self.max_conn:
            raise PoolError('connection pool exhausted')
        timeout = self._last_reconnect + .25  # 1/4 second delay between reconnection attempts
        timenow = time.time()
        if timenow > timeout or len(self._pool) <= self.min_conn:
            self._last_reconnect = timenow
            conn = AsyncConnection(self._ioloop)
            callbacks = [partial(self._pool.append,
                                 conn)]  # add new connection to the pool
            if callback:
                callbacks.append(partial(callback, *(callback_args + [conn])))

            conn.open(callbacks, *self._args, **self._kwargs)
        else:
            # recursive timeout call, retaining the parameters
            self._ioloop.add_timeout(
                timeout, partial(self._new_conn, callback, callback_args))

    def _get_free_conn(self):
        """Look for a free connection and return it.

        `None` is returned when no free connection can be found.
        """
        if self.closed:
            raise PoolError('connection pool is closed')
        for conn in self._pool:
            if not conn.isexecuting():
                return conn
        return None

    def get_connection(self, callback=None, callback_args=[]):
        """Get a connection, trying available ones, and if not available - create a new one;

        Afterwards, the callback will be called
        """
        connection = self._get_free_conn()
        if connection is None:
            self._new_conn(callback, callback_args)
        else:
            callback(*(callback_args + [connection]))

    def new_cursor(self,
                   function,
                   function_args=(),
                   callback=None,
                   cursor_kwargs={},
                   connection=None,
                   transaction=False):
        """Create a new cursor.

        If there's no connection available, a new connection will be created and
        `new_cursor` will be called again after the connection has been made.

        :param function: ``execute``, ``executemany`` or ``callproc``.
        :param function_args: A tuple with the arguments for the specified function.
        :param callback: A callable that is executed once the operation is done.
        :param cursor_kwargs: A dictionary with Psycopg's `connection.cursor`_ arguments.
        :param connection: An ``AsyncConnection`` connection. Optional.

        .. _connection.cursor: http://initd.org/psycopg/docs/connection.html#connection.cursor
        """
        if connection is not None:
            try:
                connection.cursor(function, function_args, callback,
                                  cursor_kwargs)
                return
            except (DatabaseError,
                    InterfaceError):  # Recover from lost connection
                logging.warning('Requested connection was closed')
                self._pool.remove(connection)

        # if no connection, or if exception caught
        if not transaction:
            self.get_connection(callback=self.new_cursor,
                                callback_args=[
                                    function, function_args, callback,
                                    cursor_kwargs
                                ])
        else:
            raise TransactionError

    def _clean_pool(self):
        """Close a number of inactive connections when the number of connections
        in the pool exceeds the number in `min_conn`.
        """
        if self.closed:
            raise PoolError('connection pool is closed')
        if len(self._pool) > self.min_conn:
            conns = len(self._pool) - self.min_conn
            for conn in self._pool[:]:
                if not conn.isexecuting():
                    conn.close()
                    conns -= 1
                    self._pool.remove(conn)
                    if not conns:
                        break

    def close(self):
        """Close all open connections in the pool.
        """
        if self.closed:
            raise PoolError('connection pool is closed')
        for conn in self._pool:
            if not conn.closed:
                conn.close()
        self._cleaner.stop()
        self._pool = []
        self.closed = True
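
A minimal sketch of the callback-driven cursor API above. The connection parameters are placeholders, and the shape of the result callback is an assumption, since it ultimately depends on how AsyncConnection.cursor invokes it.

# Illustrative use of AsyncPool; parameters and the callback signature are
# assumptions, not taken from the class above.
from tornado.ioloop import IOLoop

pool = AsyncPool(min_conn=1,
                 max_conn=10,
                 cleanup_timeout=10,
                 host="localhost",
                 database="example_db",
                 user="example_user")

def on_result(cursor):  # assumed to receive the finished cursor
    print(cursor.fetchall())
    IOLoop.instance().stop()

# Run a query on a pooled connection; new_cursor picks a free connection
# or asynchronously opens a new one before executing.
pool.new_cursor('execute', ('SELECT 1',), callback=on_result)
IOLoop.instance().start()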
Beispiel #60
0
class Events(threading.Thread):
    events_enable_interval = 5000

    def __init__(self,
                 capp,
                 db=None,
                 persistent=False,
                 enable_events=True,
                 io_loop=None,
                 state_save_interval=0,
                 **kwargs):
        threading.Thread.__init__(self)
        self.daemon = True

        self.io_loop = io_loop or IOLoop.instance()
        self.capp = capp

        self.db = db
        self.persistent = persistent
        self.enable_events = enable_events
        self.state = None
        self.state_save_timer = None

        if self.persistent:
            logger.debug("Loading state from '%s'...", self.db)
            state = shelve.open(self.db)
            if state:
                self.state = state['events']
            state.close()

            if state_save_interval:
                self.state_save_timer = PeriodicCallback(
                    self.save_state, state_save_interval)

        if not self.state:
            self.state = EventsState(**kwargs)

        self.timer = PeriodicCallback(self.on_enable_events,
                                      self.events_enable_interval)

    def start(self):
        threading.Thread.start(self)
        if self.enable_events:
            logger.debug("Starting enable events timer...")
            self.timer.start()

        if self.state_save_timer:
            logger.debug("Starting state save timer...")
            self.state_save_timer.start()

    def stop(self):
        if self.enable_events:
            logger.debug("Stopping enable events timer...")
            self.timer.stop()

        if self.state_save_timer:
            logger.debug("Stopping state save timer...")
            self.state_save_timer.stop()

        if self.persistent:
            self.save_state()

    def run(self):
        try_interval = 1
        while True:
            try:
                try_interval *= 2

                with self.capp.connection() as conn:
                    recv = EventReceiver(conn,
                                         handlers={"*": self.on_event},
                                         app=self.capp)
                    try_interval = 1
                    logger.debug("Capturing events...")
                    recv.capture(limit=None, timeout=None, wakeup=True)
            except (KeyboardInterrupt, SystemExit):
                try:
                    import _thread as thread
                except ImportError:
                    import thread
                thread.interrupt_main()
            except Exception as e:
                logger.error(
                    "Failed to capture events: '%s', "
                    "trying again in %s seconds.", e, try_interval)
                logger.debug(e, exc_info=True)
                time.sleep(try_interval)

    def save_state(self):
        logger.debug("Saving state to '%s'...", self.db)
        state = shelve.open(self.db)
        state['events'] = self.state
        state.close()

    def on_enable_events(self):
        # Periodically enable events for workers
        # launched after flower
        self.io_loop.run_in_executor(None, self.capp.control.enable_events)

    def on_event(self, event):
        # Call EventsState.event in ioloop thread to avoid synchronization
        self.io_loop.add_callback(partial(self.state.event, event))
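
A short wiring sketch for the events thread above. The broker URL and the shelve path are placeholders, and EventsState is assumed to be importable from the same module as Events.

# Illustrative wiring of the Events thread; broker URL and db path are
# placeholders.
from celery import Celery
from tornado.ioloop import IOLoop

capp = Celery(broker="amqp://guest@localhost//")
events = Events(capp,
                db="events_state.db",
                persistent=True,
                enable_events=True,
                io_loop=IOLoop.instance(),
                state_save_interval=60000)

events.start()                 # capture thread plus the periodic timers
try:
    IOLoop.instance().start()  # the PeriodicCallbacks fire on this loop
finally:
    events.stop()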