def __init__(self, ip_addrs=[], node_port=51348, listen_port=0,
             scheduler_node=None, scheduler_port=51347):
    """Resolve the local interfaces and start the relay tasks on each one.

    ip_addrs: local addresses / host names to serve; an empty value means a
      single ``None`` entry, which is passed to ``dispy.node_addrinfo``.
    node_port: node port; also used as ``listen_port`` when that is falsy.
    listen_port: port for the TCP listener (0 / falsy -> ``node_port``).
    scheduler_node: optional scheduler host; kept only if it resolves.
    scheduler_port: scheduler port stored on the instance.
    """
    valid_infos = []
    for candidate in (ip_addrs or [None]):
        info = dispy.node_addrinfo(candidate)
        if info:
            valid_infos.append(info)
        else:
            logger.warning('Ignoring invalid ip_addr %s', candidate)

    self.node_port = node_port
    self.listen_port = listen_port or node_port
    self.scheduler_port = scheduler_port
    resolved = dispy._node_ipaddr(scheduler_node)
    self.scheduler_ip_addrs = [resolved] if resolved else []

    # NOTE: the task attributes end up referring to the tasks created for
    # the last address only; tasks for earlier addresses still run.
    for info in valid_infos:
        self.listen_udp_task = Task(self.listen_udp_proc, info)
        self.listen_tcp_task = Task(self.listen_tcp_proc, info)
        self.sched_udp_task = Task(self.sched_udp_proc, info)
    logger.info('version %s started', dispy._dispy_version)
def __init__(self, DocumentRoot, secret='', http_host='localhost', poll_interval=60,
             ping_interval=600, hosts=[], ipv4_udp_multicast=False, certfile=None,
             keyfile=None):
    """Set up the admin server: node TCP/UDP listeners plus an HTTP(S) front end.

    DocumentRoot: directory handed to the HTTP request handler.
    secret: admin secret; combined with ``self.sign`` into ``self.auth``.
    http_host: address the HTTP server binds to; the port is taken from
      ``dispy.config.HTTPServerPort``.
    poll_interval: seconds between node polls; values below 1 are raised to
      1 with a warning.
    ping_interval: seconds between node discovery rounds.
    hosts: host names / addresses to listen on; empty means ``[None]``,
      which is passed to ``dispy.host_addrinfo``.
    ipv4_udp_multicast: passed to ``dispy.host_addrinfo`` and stored.
    certfile, keyfile: when ``certfile`` is set, the HTTP listening socket
      is wrapped with TLS using these files.

    Raises Exception when none of ``hosts`` yields a usable address.
    """
    http_port = dispy.config.HTTPServerPort
    # NOTE(review): eval() of configuration strings; they are expected to be
    # integer literals -- confirm they never come from untrusted input.
    self.node_port = eval(dispy.config.NodePort)
    self.info_port = eval(dispy.config.ClientPort)
    self.lock = threading.Lock()
    self.client_uid = None
    self.client_uid_time = 0
    self.nodes = {}
    self.updates = {}
    if poll_interval < 1:
        logger.warning('invalid poll_interval value %s; it must be at least 1',
                       poll_interval)
        poll_interval = 1
    self.poll_interval = poll_interval
    self.ping_interval = ping_interval
    self.secret = secret
    self.keyfile = keyfile
    self.certfile = certfile
    self.ipv4_udp_multicast = bool(ipv4_udp_multicast)

    self.addrinfos = []
    for host in (hosts or [None]):
        addrinfo = dispy.host_addrinfo(host=host, ipv4_multicast=self.ipv4_udp_multicast)
        if not addrinfo:
            logger.warning('Ignoring invalid host %s', host)
            continue
        self.addrinfos.append(addrinfo)
    if not self.addrinfos:
        raise Exception('No valid host name / IP address found')

    # Per-instance signature: random bytes mixed with the local addresses.
    sign = hashlib.sha1(os.urandom(20))
    for addrinfo in self.addrinfos:
        sign.update(addrinfo.ip.encode())
    self.sign = sign.hexdigest()
    self.auth = dispy.auth_code(self.secret, self.sign)

    self.tcp_tasks = []
    self.udp_tasks = []
    # One UDP server per distinct bind address (several addrinfos may share one).
    udp_addrinfos = {}
    for addrinfo in self.addrinfos:
        self.tcp_tasks.append(Task(self.tcp_server, addrinfo))
        udp_addrinfos[addrinfo.bind_addr] = addrinfo
    for addrinfo in udp_addrinfos.values():
        self.udp_tasks.append(Task(self.udp_server, addrinfo))

    self._server = HTTPServer(
        (http_host, http_port),
        lambda *args: self.__class__._HTTPRequestHandler(self, DocumentRoot, *args))
    if certfile:
        # ssl.wrap_socket() was deprecated in Python 3.7 and removed in 3.12;
        # a server-side SSLContext is the supported equivalent.
        protocol = getattr(ssl, 'PROTOCOL_TLS_SERVER', None)
        if protocol is None:  # Python < 3.6
            protocol = ssl.PROTOCOL_SSLv23
        ssl_ctx = ssl.SSLContext(protocol)
        ssl_ctx.load_cert_chain(certfile=certfile, keyfile=keyfile)
        self._server.socket = ssl_ctx.wrap_socket(self._server.socket, server_side=True)
    self.timer = Task(self.timer_proc)
    # serve_forever() blocks, so the HTTP server runs in a daemon thread.
    self._httpd_thread = threading.Thread(target=self._server.serve_forever)
    self._httpd_thread.daemon = True
    self._httpd_thread.start()
    self.client_host = self._server.socket.getsockname()[0]
    logger.info('Started HTTP%s server at %s:%s', 's' if certfile else '',
                self.client_host, self._server.socket.getsockname()[1])
def communicate(self, input=None):
    """Similar to Popen's communicate.

    Must be used with 'yield' as
    'stdout, stderr = yield async_pipe.communicate()'

    'input' must be either data (str or bytes; str is encoded with the
    default encoding) or an object with 'read' method (i.e., regular file
    object or AsyncFile object). Such an object is closed after it has been
    copied to stdin. If no input is given, stdin (when present) is closed
    so the child sees end-of-file, as Popen.communicate does.

    The result is a tuple (stdout_data, stderr_data). Each element is bytes,
    or None when the corresponding pipe is not present.
    """
    def write_proc(fd, input, task=None):
        size = 16384
        if isinstance(input, (str, bytes)):
            # Pipe writes take bytes; encode text the same way the
            # file-reading branch below encodes str chunks.
            data = input.encode() if isinstance(input, str) else input
            n = yield fd.write(data, full=True)
            if n != len(data):
                raise IOError('write failed')
        else:
            # TODO: how to know if 'input' is file object for
            # on-disk file?
            if hasattr(input, 'seek') and hasattr(input, 'fileno'):
                read_func = partial_func(os.read, input.fileno())
            else:
                read_func = input.read
            while 1:
                data = yield read_func(size)
                if not data:
                    break
                if isinstance(data, str):
                    data = data.encode()
                n = yield fd.write(data, full=True)
                if n != len(data):
                    raise IOError('write failed')
            input.close()
        fd.close()

    def read_proc(fd, task=None):
        size = 16384
        buflist = []
        while 1:
            buf = yield fd.read(size)
            if not buf:
                break
            buflist.append(buf)
        fd.close()
        data = b''.join(buflist)
        raise StopIteration(data)

    # Readers are started as separate tasks before stdin is written, so
    # output is drained while input is being fed.
    if self.stdout:
        stdout_task = Task(read_proc, self.stdout)
    if self.stderr:
        stderr_task = Task(read_proc, self.stderr)
    if input and self.stdin:
        stdin_task = Task(write_proc, self.stdin, input)
        yield stdin_task.finish()
    elif self.stdin:
        # Without input, close stdin anyway; otherwise a child reading stdin
        # never sees EOF and the readers above may wait forever.
        self.stdin.close()
    out = (yield stdout_task.finish()) if self.stdout else None
    err = (yield stderr_task.finish()) if self.stderr else None
    # TODO: Is it possible for 'wait' to block even after I/O is finished?
    self.wait()
    raise StopIteration((out, err))
def __init__(self, ip_addrs=[], relay_port=0, scheduler_nodes=[], scheduler_port=0,
             ipv4_udp_multicast=False, secret='', certfile=None, keyfile=None):
    """Create the relay and start its TCP / UDP relay tasks.

    ip_addrs: local addresses / host names to serve; empty means ``[None]``.
    relay_port: port used by the relay listeners.
    scheduler_nodes: scheduler hosts; each one that resolves is probed via
      a ``verify_broadcast`` task on every UDP bind address.
    scheduler_port: scheduler port placed in those probe messages.
    ipv4_udp_multicast: passed to ``dispy.host_addrinfo`` as
      ``ipv4_multicast`` and stored.
    secret: stored for authentication.
    certfile, keyfile: TLS files, stored as absolute paths (or None).
    """
    self.ipv4_udp_multicast = bool(ipv4_udp_multicast)
    local_infos = []
    for host in (ip_addrs or [None]):
        info = dispy.host_addrinfo(host=host, ipv4_multicast=self.ipv4_udp_multicast)
        if info:
            local_infos.append(info)
        else:
            logger.warning('Ignoring invalid ip_addr %s', host)

    self.node_port = eval(dispy.config.NodePort)
    self.scheduler_port = scheduler_port
    self.relay_port = relay_port
    self.ip_addrs = set()
    self.scheduler_ip_addr = None
    self.secret = secret
    self.certfile = os.path.abspath(certfile) if certfile else None
    self.keyfile = os.path.abspath(keyfile) if keyfile else None

    # One TCP relay per address; UDP tasks per distinct bind address.
    per_bind_addr = {}
    for info in local_infos:
        self.ip_addrs.add(info.ip)
        Task(self.relay_tcp_proc, info)
        per_bind_addr[info.bind_addr] = info

    candidates = [dispy._node_ipaddr(node) for node in scheduler_nodes]
    known_schedulers = [ip for ip in candidates if ip]

    for bind_addr, info in per_bind_addr.items():
        Task(self.relay_udp_proc, bind_addr, info)
        Task(self.sched_udp_proc, bind_addr, info)
        for sched_ip in known_schedulers:
            probe = {'version': __version__, 'ip_addrs': [sched_ip],
                     'port': self.scheduler_port, 'sign': None}
            Task(self.verify_broadcast, info, probe)
    logger.info('version %s started', dispy._dispy_version)
def relay_tcp_proc(self, addrinfo, task=None):
    """Accept TCP 'PING:' messages on ``addrinfo.ip:self.relay_port``.

    Each connection is served by a ``tcp_req`` task. The task reads an auth
    code (read but not checked here) and one message. A well-formed message
    whose version matches is passed to a ``self.verify_broadcast`` task.
    TLS is used according to ``self.keyfile`` / ``self.certfile``. A failed
    TLS handshake on accept is logged and skipped instead of ending the
    listener. The listening socket is closed when the task is closed.
    """
    task.set_daemon()
    auth_len = len(dispy.auth_code('', ''))
    tcp_sock = AsyncSocket(socket.socket(addrinfo.family, socket.SOCK_STREAM),
                           keyfile=self.keyfile, certfile=self.certfile)
    tcp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    tcp_sock.bind((addrinfo.ip, self.relay_port))
    tcp_sock.listen(8)

    def tcp_req(conn, addr, task=None):
        conn.settimeout(dispy.MsgTimeout)
        try:
            msg = yield conn.recvall(auth_len)  # auth code; discarded below
            msg = yield conn.recv_msg()
        except Exception:
            logger.debug(traceback.format_exc())
            logger.debug('Ignoring invalid TCP message from %s:%s', addr[0], addr[1])
            raise StopIteration
        finally:
            conn.close()
        try:
            msg = deserialize(msg[len('PING:'.encode()):])
            version = msg['version']
        except Exception:
            logger.debug('Ignoring ping message from %s (%s)', addr[0], addr[1])
            logger.debug(traceback.format_exc())
            raise StopIteration
        # Checked outside the catch-all above, so a mismatch is reported once
        # as a version warning and not also logged as a malformed message.
        if version != __version__:
            logger.warning('Ignoring %s due to version mismatch: %s / %s',
                           msg.get('ip_addrs', None), version, __version__)
            raise StopIteration
        Task(self.verify_broadcast, addrinfo, msg)

    while 1:
        try:
            conn, addr = yield tcp_sock.accept()
        except ssl.SSLError as err:
            logger.debug('SSL connection failed: %s', str(err))
            continue
        except GeneratorExit:
            break
        Task(tcp_req, conn, addr)
    tcp_sock.close()
def __init__(self, DocumentRoot, secret='', http_host='localhost', http_port=8181,
             info_port=51347, node_port=51348, poll_interval=60, ping_interval=600,
             ip_addrs=[], ipv4_udp_multicast=False, certfile=None, keyfile=None):
    """Set up the admin server: node TCP/UDP listeners plus an HTTP(S) front end.

    DocumentRoot: directory handed to the HTTP request handler.
    secret: admin secret; combined with ``self.sign`` into ``self.auth``.
    http_host, http_port: address the HTTP server binds to.
    info_port, node_port: ports stored on the instance for node traffic.
    poll_interval: seconds between node polls; values below 1 are raised to
      1 with a warning.
    ping_interval: seconds between node discovery rounds.
    ip_addrs: local addresses / host names to listen on; empty means
      ``[None]``, which is passed to ``dispy.host_addrinfo``.
    ipv4_udp_multicast: passed to ``dispy.host_addrinfo`` and stored.
    certfile, keyfile: when ``certfile`` is set, the HTTP listening socket
      is wrapped with TLS using these files.

    Raises Exception when none of ``ip_addrs`` yields a usable address.
    """
    self.lock = threading.Lock()
    self.client_uid = None
    self.client_uid_time = 0
    self.nodes = {}
    self.updates = {}
    if poll_interval < 1:
        logger.warning('invalid poll_interval value %s; it must be at least 1',
                       poll_interval)
        poll_interval = 1
    self.info_port = info_port
    self.poll_interval = poll_interval
    self.ping_interval = ping_interval
    self.secret = secret
    self.node_port = node_port
    self.keyfile = keyfile
    self.certfile = certfile
    self.ipv4_udp_multicast = bool(ipv4_udp_multicast)

    # Keyed by IP, so duplicate addresses collapse into one entry.
    self.addrinfos = {}
    for ip_addr in (ip_addrs or [None]):
        addrinfo = dispy.host_addrinfo(host=ip_addr, ipv4_multicast=self.ipv4_udp_multicast)
        if not addrinfo:
            logger.warning('Ignoring invalid ip_addr %s', ip_addr)
            continue
        self.addrinfos[addrinfo.ip] = addrinfo
    if not self.addrinfos:
        raise Exception('No valid IP address found')
    self.http_port = http_port

    # Per-instance signature: random bytes mixed with the local addresses.
    sign = hashlib.sha1(os.urandom(20))
    for ip_addr in self.addrinfos:
        sign.update(ip_addr.encode())
    self.sign = sign.hexdigest()
    self.auth = dispy.auth_code(self.secret, self.sign)

    self.tcp_tasks = []
    self.udp_tasks = []
    # One UDP server per distinct bind address (several addrinfos may share one).
    udp_addrinfos = {}
    for addrinfo in self.addrinfos.values():
        self.tcp_tasks.append(Task(self.tcp_server, addrinfo))
        udp_addrinfos[addrinfo.bind_addr] = addrinfo
    for addrinfo in udp_addrinfos.values():
        self.udp_tasks.append(Task(self.udp_server, addrinfo))

    self._server = HTTPServer(
        (http_host, http_port),
        lambda *args: self.__class__._HTTPRequestHandler(self, DocumentRoot, *args))
    if certfile:
        # ssl.wrap_socket() was deprecated in Python 3.7 and removed in 3.12;
        # a server-side SSLContext is the supported equivalent.
        protocol = getattr(ssl, 'PROTOCOL_TLS_SERVER', None)
        if protocol is None:  # Python < 3.6
            protocol = ssl.PROTOCOL_SSLv23
        ssl_ctx = ssl.SSLContext(protocol)
        ssl_ctx.load_cert_chain(certfile=certfile, keyfile=keyfile)
        self._server.socket = ssl_ctx.wrap_socket(self._server.socket, server_side=True)
    self.timer = Task(self.timer_proc)
    # serve_forever() blocks, so the HTTP server runs in a daemon thread.
    self._httpd_thread = threading.Thread(target=self._server.serve_forever)
    self._httpd_thread.daemon = True
    self._httpd_thread.start()
    logger.info('Started HTTP%s server at %s', 's' if certfile else '',
                ':'.join(map(str, self._server.socket.getsockname())))
def tcp_req(conn, addr, task=None):
    """Handle one relay TCP connection carrying an auth code and a 'PING:' message.

    ``auth_len``, ``self`` and ``addrinfo`` come from the enclosing scope.
    The auth code is read but not checked here. A malformed message is
    dropped with debug logging. A message with a different version is
    dropped with a warning. Otherwise the message is passed to a
    ``self.verify_broadcast`` task.
    """
    conn.settimeout(dispy.MsgTimeout)
    try:
        msg = yield conn.recvall(auth_len)  # auth code; discarded below
        msg = yield conn.recv_msg()
    except Exception:
        logger.debug(traceback.format_exc())
        logger.debug('Ignoring invalid TCP message from %s:%s', addr[0], addr[1])
        raise StopIteration
    finally:
        conn.close()
    try:
        msg = deserialize(msg[len('PING:'.encode()):])
        version = msg['version']
    except Exception:
        logger.debug('Ignoring ping message from %s (%s)', addr[0], addr[1])
        logger.debug(traceback.format_exc())
        raise StopIteration
    # Checked outside the catch-all above, so a mismatch is reported once as
    # a version warning and not also logged as a malformed message.
    if version != __version__:
        logger.warning(
            'Ignoring %s due to version mismatch: %s / %s',
            msg.get('ip_addrs', None), version, __version__)
        raise StopIteration
    Task(self.verify_broadcast, addrinfo, msg)
def set_secret(self, secret, task=None):
    """Install a new admin secret and retry nodes not yet authenticated.

    While holding ``self.lock`` the secret is stored, and a
    ``get_node_info`` task is started for every known node whose
    ``_priv.auth`` is unset. After the lock is released the timer task is
    woken with ``self.timer.resume()``.
    """
    with self.lock:
        self.secret = secret
        pending = [node for node in self.nodes.values() if not node._priv.auth]
        for node in pending:
            Task(self.get_node_info, node)
    self.timer.resume()
def tcp_server(self, addrinfo, task=None):
    """Accept TCP connections on ``addrinfo.ip:self.info_port``.

    Each accepted connection is handed to a new ``self.tcp_req`` task. SSL
    handshake failures and other accept errors are logged at debug level
    and skipped. The loop ends when the task is closed (GeneratorExit). If
    the port cannot be bound, a warning is logged and the task ends. The
    listening socket is closed on every exit path.
    """
    task.set_daemon()
    sock = AsyncSocket(socket.socket(addrinfo.family, socket.SOCK_STREAM),
                       keyfile=self.keyfile, certfile=self.certfile)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    try:
        sock.bind((addrinfo.ip, self.info_port))
    except Exception:
        logger.warning('Could not bind TCP server to %s:%s', addrinfo.ip, self.info_port)
        logger.debug(traceback.format_exc())
        # Release the socket; it used to leak on this path.
        sock.close()
        raise StopIteration
    logger.info('dispyadmin TCP server at %s:%s', addrinfo.ip, self.info_port)
    sock.listen(16)
    try:
        while 1:
            try:
                conn, addr = yield sock.accept()
            except ssl.SSLError as err:
                logger.debug('SSL connection failed: %s', str(err))
                continue
            except GeneratorExit:
                break
            except Exception:
                logger.debug(traceback.format_exc())
                continue
            Task(self.tcp_req, conn, addr)
    finally:
        sock.close()
def sched_udp_proc(self, bind_addr, addrinfo, task=None):
    """Listen for UDP 'PING:' messages and answer them over TCP via the relay.

    The socket is bound to ``(bind_addr, self.scheduler_port)``. It joins the
    group in ``addrinfo.broadcast`` for IPv6 always, and for IPv4 only when
    ``self.ipv4_udp_multicast`` is set. For each well-formed ping whose
    version matches, and only once ``self.scheduler_ip_addr`` is known, a
    ``relay_msg`` task connects to the sender's ``ip_addr``/``port``. It then
    sends a 'PING:' message carrying the scheduler address and port.
    """
    task.set_daemon()

    def relay_msg(msg, task=None):
        # 'relay' marks the message as coming from the relay.
        relay = {'ip_addrs': self.scheduler_ip_addr, 'port': self.scheduler_port,
                 'version': __version__, 'relay': 'y'}
        sock = AsyncSocket(socket.socket(addrinfo.family, socket.SOCK_STREAM),
                           keyfile=self.keyfile, certfile=self.certfile)
        sock.settimeout(dispy.MsgTimeout)
        try:
            yield sock.connect((msg['ip_addr'], msg['port']))
            yield sock.sendall(dispy.auth_code(self.secret, msg['sign']))
            yield sock.send_msg('PING:'.encode() + serialize(relay))
        finally:
            # Close even when connect/send fails (the socket leaked before).
            sock.close()

    sched_sock = AsyncSocket(socket.socket(addrinfo.family, socket.SOCK_DGRAM))
    sched_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    if hasattr(socket, 'SO_REUSEPORT'):
        try:
            sched_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        except Exception:
            # The constant may exist while the option is still rejected.
            pass
    sched_sock.bind((bind_addr, self.scheduler_port))
    if addrinfo.family == socket.AF_INET:
        if self.ipv4_udp_multicast:
            mreq = socket.inet_aton(addrinfo.broadcast) + socket.inet_aton(addrinfo.ip)
            sched_sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
    else:  # addrinfo.family == socket.AF_INET6:
        mreq = socket.inet_pton(addrinfo.family, addrinfo.broadcast)
        mreq += struct.pack('@I', addrinfo.ifn)
        sched_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, mreq)
        # NOTE(review): IPV6_V6ONLY is normally only honored before bind();
        # here it runs after bind and any failure is ignored -- confirm intent.
        try:
            sched_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
        except Exception:
            pass

    while 1:
        msg, addr = yield sched_sock.recvfrom(1024)
        if not msg.startswith('PING:'.encode()):
            logger.debug('Ignoring message from %s (%s)', addr[0], addr[1])
            continue
        try:
            msg = deserialize(msg[len('PING:'.encode()):])
            version = msg['version']
        except Exception:
            continue
        # Explicit comparison: an 'assert' here is stripped under 'python -O'.
        if version != __version__:
            continue
        if not self.scheduler_ip_addr:
            continue
        Task(relay_msg, msg)
def listen_tcp_proc(self, addrinfo, task=None):
    """Accept TCP 'PING:' messages on ``addrinfo.ip:self.listen_port`` and rebroadcast them.

    Each ping is handled as follows:
    - It is deserialized and version-checked.
    - Its sender becomes the current scheduler (``self.scheduler_ip_addrs``
      and ``self.scheduler_port``).
    - Unless the ping is already marked as relayed, it is re-sent as a UDP
      'PING:' datagram to ``(self._broadcast, self.node_port)``. The sending
      socket is bound to ``addrinfo.ip``.
    """
    task.set_daemon()
    tcp_sock = AsyncSocket(socket.socket(addrinfo.family, socket.SOCK_STREAM))
    tcp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    tcp_sock.bind((addrinfo.ip, self.listen_port))
    tcp_sock.listen(8)

    bc_sock = AsyncSocket(socket.socket(addrinfo.family, socket.SOCK_DGRAM))
    if addrinfo.family == socket.AF_INET:
        bc_sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
    else:  # addrinfo.sock_family == socket.AF_INET6
        bc_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS,
                           struct.pack('@i', 1))
        bc_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, addrinfo.ifn)
    bc_sock.bind((addrinfo.ip, 0))

    auth_len = len(dispy.auth_code('', ''))

    def tcp_req(conn, addr, task=None):
        conn.settimeout(5)
        try:
            msg = yield conn.recvall(auth_len)  # auth code; discarded below
            msg = yield conn.recv_msg()
        except Exception:
            logger.debug(traceback.format_exc())
            logger.debug('Ignoring invalid TCP message from %s:%s', addr[0], addr[1])
            raise StopIteration
        finally:
            conn.close()
        logger.debug('Ping message from %s (%s)', addr[0], addr[1])
        try:
            info = deserialize(msg[len('PING:'.encode()):])
            version = info['version']
            # Extract everything first so the scheduler is never half-updated.
            sched_ip_addrs = info['ip_addrs'] + [addr[0]]
            sched_port = info['port']
        except Exception:
            logger.debug('Ignoring ping message from %s (%s)', addr[0], addr[1])
            logger.debug(traceback.format_exc())
            raise StopIteration
        # Checked outside the catch-all above, so a mismatch is reported once
        # as a version warning and not also logged as a malformed message.
        if version != __version__:
            logger.warning('Ignoring %s due to version mismatch: %s / %s',
                           info['ip_addrs'], version, __version__)
            raise StopIteration
        # TODO: since dispynetrelay is not aware of computations
        # closing, if more than one client sends ping, nodes will
        # respond to different clients
        self.scheduler_ip_addrs = sched_ip_addrs
        self.scheduler_port = sched_port
        if info.get('relay', None):
            logger.debug('Ignoring ping back (from %s)', addr[0])
            raise StopIteration
        logger.debug('relaying ping from %s / %s', info['ip_addrs'], addr[0])
        if self.node_port == self.listen_port:
            info['relay'] = 'y'  # 'check if this message loops back to self
        # NOTE(review): self._broadcast is not assigned in any visible
        # __init__; confirm it is set elsewhere (addrinfo.broadcast exists).
        yield bc_sock.sendto('PING:'.encode() + serialize(info),
                             (self._broadcast, self.node_port))

    while 1:
        conn, addr = yield tcp_sock.accept()
        Task(tcp_req, conn, addr)
def relay_udp_proc(self, bind_addr, addrinfo, task=None):
    """Receive UDP 'PING:' datagrams on the relay port and verify them.

    The socket is bound to ``(bind_addr, self.relay_port)``. It joins the
    group in ``addrinfo.broadcast`` for IPv6 always, and for IPv4 only with
    ``self.ipv4_udp_multicast``. The following are dropped:
    - datagrams without the 'PING:' prefix;
    - datagrams sent from one of this relay's own addresses;
    - messages that fail to deserialize;
    - messages with a different version.
    Every remaining message goes to a ``self.verify_broadcast`` task.
    """
    task.set_daemon()
    prefix = 'PING:'.encode()
    sock = AsyncSocket(socket.socket(addrinfo.family, socket.SOCK_DGRAM))
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    if hasattr(socket, 'SO_REUSEPORT'):
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        except Exception:
            pass
    sock.bind((bind_addr, self.relay_port))

    if addrinfo.family != socket.AF_INET:
        # addrinfo.family == socket.AF_INET6
        group = socket.inet_pton(addrinfo.family, addrinfo.broadcast)
        sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP,
                        group + struct.pack('@I', addrinfo.ifn))
        try:
            sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
        except Exception:
            pass
    elif self.ipv4_udp_multicast:
        membership = socket.inet_aton(addrinfo.broadcast) + socket.inet_aton(addrinfo.ip)
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, membership)

    while True:
        data, sender = yield sock.recvfrom(1024)
        if not data.startswith(prefix):
            logger.debug('Ignoring message from %s', sender[0])
            continue
        if sender[0] in self.ip_addrs:
            logger.debug('Ignoring loop back ping from %s' % sender[0])
            continue
        try:
            ping = deserialize(data[len(prefix):])
            if ping['version'] != __version__:
                logger.warning(
                    'Ignoring %s due to version mismatch: %s / %s',
                    ping['ip_addrs'], ping['version'], __version__)
                continue
        except Exception:
            logger.debug('Ignoring ping message from %s (%s)', sender[0], sender[1])
            logger.debug(traceback.format_exc())
            continue
        Task(self.verify_broadcast, addrinfo, ping)
def timer_proc(self, task=None):
    """Periodic housekeeping task: node discovery and node-info polling.

    Starts ``discover_nodes`` immediately, then again whenever at least
    ``self.ping_interval`` seconds have passed since the previous round.
    On every wake-up it starts ``update_node_info`` for each node whose
    ``_priv.auth`` is set.

    While no client has been seen for more than ``5 * ping_interval``
    seconds, the sleep interval doubles. It never doubles beyond
    ``ping_interval``; previously it grew without bound, so discovery
    stopped happening on schedule. A value passed through ``resume()``
    replaces the interval.
    """
    task.set_daemon()
    Task(self.discover_nodes)
    last_ping = time.time()
    interval = self.poll_interval
    while 1:
        now = time.time()
        if (now - last_ping) >= self.ping_interval:
            Task(self.discover_nodes)
            last_ping = now
        if (now - self.client_uid_time) > (5 * self.ping_interval):
            # Back off while idle, but keep waking at least every ping_interval.
            if interval < self.ping_interval:
                interval = min(2 * interval, self.ping_interval)
        with self.lock:
            nodes = list(self.nodes.values())
        # TODO: it may be better to have nodes send updates periodically
        for node in nodes:
            if node._priv.auth:
                Task(self.update_node_info, node)
        update = yield task.sleep(interval)
        if update:
            interval = update
def communicate(self, input=None):
    """Similar to Popen's communicate.

    Must be used with 'yield' as
    'stdout, stderr = yield async_pipe.communicate()'

    'input' must be either data (str) or an object with 'read' method
    (i.e., regular file object or AsyncFile object). Such an object is
    closed after it has been copied to stdin. If no input is given, stdin
    (when present) is closed so the child sees end-of-file, as
    Popen.communicate does.

    The result is ONE tuple (stdout_data, stderr_data). Each element is
    None when the corresponding pipe is not present.
    """
    def write_proc(fd, input, task=None):
        size = 16384
        if isinstance(input, str):
            n = yield fd.write(input, full=True)
            if n != len(input):
                raise IOError('write failed')
        else:
            # TODO: how to know if 'input' is file object for
            # on-disk file?
            if hasattr(input, 'seek') and hasattr(input, 'fileno'):
                read_func = partial_func(os.read, input.fileno())
            else:
                read_func = input.read
            while 1:
                data = yield read_func(size)
                if not data:
                    break
                n = yield fd.write(data, full=True)
                if n != len(data):
                    raise IOError('write failed')
            input.close()
        fd.close()

    def read_proc(fd, task=None):
        size = 16384
        buflist = []
        while 1:
            buf = yield fd.read(size)
            if not buf:
                break
            buflist.append(buf)
        fd.close()
        data = ''.join(buflist)
        raise StopIteration(data)

    if self.stdout:
        stdout_task = Task(read_proc, self.stdout)
    if self.stderr:
        stderr_task = Task(read_proc, self.stderr)
    if input and self.stdin:
        stdin_task = Task(write_proc, self.stdin, input)
        yield stdin_task.finish()
    elif self.stdin:
        # Without input, close stdin anyway; otherwise a child reading stdin
        # never sees EOF and the readers above may wait forever.
        self.stdin.close()
    out = (yield stdout_task.finish()) if self.stdout else None
    err = (yield stderr_task.finish()) if self.stderr else None
    # Pass a single tuple. StopIteration(out, err) carried two separate
    # arguments, which does not deliver the (stdout, stderr) pair the
    # docstring promises (the task value presumably comes from args[0]).
    raise StopIteration((out, err))
# NOTE(review): the four statements below look like the tail of a branch
# (write the config out, then exit) whose header is not visible in this
# excerpt -- confirm their indentation against the full file.
cfg = open(cfg, 'w')
config.write(cfg)
cfg.close()
exit(0)
# Publish the configured ports through dispy.config before creating the
# server: HTTPServerPort is stored as an int; ClientPort and NodePort are
# stored as strings holding integers. pop() removes them from 'config'.
dispy.config.HTTPServerPort = int(config.pop('http_port'))
dispy.config.ClientPort = str(int(config.pop('info_port')))
dispy.config.NodePort = str(int(config.pop('node_port')))
# The remaining config entries become keyword arguments of the server.
server = DispyAdminServer(DocumentRoot, **config)
# Interactive command loop; Ctrl-C (KeyboardInterrupt) also leaves it.
while 1:
    try:
        cmd = input('\nEnter "quit" or "exit" to quit,\n'
                    ' "secret" to add admin secret,\n'
                    ' "scan" to find nodes :')
    except KeyboardInterrupt:
        break
    cmd = cmd.strip().lower()
    if cmd == 'secret':
        secret = input('Enter admin secret :')
        secret = secret.strip()
        # Empty input leaves the current secret unchanged.
        if secret:
            server.set_secret(secret)
    elif cmd == 'scan':
        Task(server.discover_nodes)
    elif cmd == 'quit' or cmd == 'exit':
        break
    else:
        print('Ignoring invalid command')
server.shutdown()
class DispyAdminServer(object): class _NodePriv(object): def __init__(self, port, sock_family, sign): self.port = port self.sock_family = sock_family self.sign = sign self.auth = None class _NodeInfo(object): ip_re = re.compile(r'^((\d+\.\d+\.\d+\.\d+)|([0-9a-f:]+))$') def __init__(self, ip_addr, port, sock_family, sign): self.ip_addr = ip_addr self.name = '' self.max_cpus = 0 self.cpus = 0 self.scheduler_ip = None self.clients_done = 0 self.jobs_done = 0 self.cpu_time = 0 self.busy = 0 self.avail_info = None self.service_start = 0 self.service_stop = 0 self.service_end = 0 self.serve = None self.update_time = 0 self._priv = DispyAdminServer._NodePriv(port, sock_family, sign) class _HTTPRequestHandler(BaseHTTPRequestHandler): def __init__(self, ctx, DocumentRoot, *args): self._ctx = ctx self.DocumentRoot = DocumentRoot BaseHTTPRequestHandler.__init__(self, *args) @staticmethod def json_encode_nodes(arg): nodes = [dict(node.__dict__) for node in dict_iter(arg, 'values')] for node in nodes: node.pop('_priv', None) if node['avail_info']: node['avail_info'] = node['avail_info'].__dict__ return nodes def log_message(self, fmt, *args): # logger.debug('HTTP client %s: %s' % (self.client_address[0], fmt % args)) return def do_GET(self): path = urlparse(self.path).path.lstrip('/') if path == '' or path == 'index.html': path = 'admin.html' path = os.path.join(self.DocumentRoot, path) try: with open(path) as fd: data = fd.read() if path.endswith('.html'): if path.endswith('admin.html') or path.endswith( 'admin_node.html'): data = data % { 'POLL_INTERVAL': str(self._ctx.poll_interval), 'NODE_PORT': str(self._ctx.node_port) } content_type = 'text/html' elif path.endswith('.js'): content_type = 'text/javascript' elif path.endswith('.css'): content_type = 'text/css' elif path.endswith('.ico'): content_type = 'image/x-icon' self.send_response(200) self.send_header('Content-Type', content_type) if content_type == 'text/css' or content_type == 'text/javascript': 
self.send_header('Cache-Control', 'private, max-age=86400') self.end_headers() self.wfile.write(data.encode()) return except Exception: logger.warning('HTTP client %s: Could not read/send "%s"', self.client_address[0], path) logger.debug(traceback.format_exc()) self.send_error(404) return def do_POST(self): try: form = cgi.FieldStorage(fp=self.rfile, headers=self.headers, environ={'REQUEST_METHOD': 'POST'}) client_request = self.path[1:] except Exception: logger.debug('Ignoring invalid POST request from %s', self.client_address[0]) self.send_error(400) return if client_request == 'update': uid = None for item in form.list: if item.name == 'uid': uid = item.value.strip() break if uid != self._ctx.client_uid: self.send_error(400, 'invalid uid') return self._ctx.client_uid_time = time.time() self._ctx.lock.acquire() nodes = self.__class__.json_encode_nodes(self._ctx.updates) self._ctx.updates.clear() self._ctx.lock.release() self.send_response(200) self.send_header('Content-Type', 'application/json; charset=utf-8') self.end_headers() self.wfile.write(json.dumps(nodes).encode()) return elif client_request == 'node_info': ip_addr = None uid = None for item in form.list: if item.name == 'host': # if it looks like IP address, skip resolving if re.match(DispyAdminServer._NodeInfo.ip_re, item.value): ip_addr = item.value else: ip_addr = dispy._node_ipaddr(item.value) elif item.name == 'uid': uid = item.value.strip() if uid != self._ctx.client_uid: self.send_error(400, 'invalid uid') return self._ctx.client_uid_time = time.time() node = self._ctx.nodes.get(ip_addr, None) self.send_response(200) self.send_header('Content-Type', 'application/json; charset=utf-8') self.end_headers() if node: node = dict(node.__dict__) node.pop('_priv', None) if node['avail_info']: node['avail_info'] = node['avail_info'].__dict__ else: node = {} self.wfile.write(json.dumps(node).encode()) return elif client_request == 'status': uid = None for item in form.list: if item.name == 'uid': uid = 
item.value.strip() break if uid != self._ctx.client_uid: self.send_error(400, 'invalid uid') return self._ctx.client_uid_time = time.time() self._ctx.lock.acquire() nodes = self.__class__.json_encode_nodes(self._ctx.nodes) self._ctx.lock.release() self.send_response(200) self.send_header('Content-Type', 'application/json; charset=utf-8') self.end_headers() self.wfile.write(json.dumps(nodes).encode()) return elif client_request == 'get_uid': uid = None for item in form.list: if item.name == 'uid': uid = item.value.strip() elif item.name == 'poll_interval': try: poll_interval = int(item.value) assert poll_interval >= 5 except Exception: self.send_error(400, 'invalid poll interval') return # TODO: only allow from http server? uid = self._ctx.set_uid(self.client_address[0], poll_interval, uid) if not uid: self.send_error(400, 'invalid uid') return self.send_response(200) self.send_header('Content-Type', 'application/json; charset=utf-8') self.end_headers() self.wfile.write(json.dumps(uid).encode()) return elif client_request == 'set_secret': secret = None uid = None for item in form.list: if item.name == 'secret': secret = item.value.strip() elif item.name == 'uid': uid = item.value.strip() if secret and uid == self._ctx.client_uid: self._ctx.client_uid_time = time.time() self._ctx.set_secret(secret) self.send_response(200) self.send_header('Content-Type', 'application/json; charset=utf-8') self.end_headers() self.wfile.write(json.dumps(0).encode()) else: self.send_error(400) return elif client_request == 'add_node': host = '' port = None uid = None for item in form.list: if item.name == 'host': host = item.value elif item.name == 'port': try: port = int(item.value) except Exception: port = None elif item.name == 'uid': uid = item.value.strip() if uid != self._ctx.client_uid: self.send_error(400, 'invalid uid') return if host and port: ip_addr = dispy._node_ipaddr(host) if ip_addr: info = {'ip_addr': ip_addr, 'port': port} Task(self._ctx.add_node, info) 
self.send_response(200) self.send_header('Content-Type', 'application/json; charset=utf-8') self.end_headers() self.wfile.write(json.dumps(0).encode()) else: self.send_error(400) return elif client_request == 'service_time': hosts = [] svc_time = None control = None uid = None for item in form.list: if item.name == 'hosts': hosts = [str(host) for host in json.loads(item.value)] elif item.name == 'control': control = item.value elif item.name == 'time': svc_time = item.value.strip() elif item.name == 'uid': uid = item.value.strip() if uid != self._ctx.client_uid: self.send_error(400, 'invalid uid') return self._ctx.client_uid_time = time.time() for host in hosts: Task(self._ctx.service_time, host, control, svc_time) self.send_response(200) self.send_header('Content-Type', 'application/json; charset=utf-8') self.end_headers() self.wfile.write(json.dumps(0).encode()) return elif client_request == 'set_cpus': hosts = [] cpus = None uid = None for item in form.list: if item.name == 'hosts': hosts = [str(host) for host in json.loads(item.value)] if not hosts: self.send_error(400, 'invalid nodes') return elif item.name == 'cpus': cpus = item.value if cpus is not None: try: cpus = int(item.value) except Exception: self.send_error(400, 'invalid CPUs') return elif item.name == 'uid': uid = item.value.strip() if uid != self._ctx.client_uid: self.send_error(400, 'invalid uid') return for host in hosts: Task(self._ctx.set_cpus, host, cpus) self._ctx.client_uid_time = time.time() self.send_response(200) self.send_header('Content-Type', 'application/json; charset=utf-8') self.end_headers() self.wfile.write(json.dumps(0).encode()) return elif client_request == 'serve_clients': host = '' serve = None uid = None for item in form.list: if item.name == 'host': host = item.value elif item.name == 'serve': serve = item.value try: serve = int(serve) except Exception: pass elif item.name == 'uid': uid = item.value.strip() if (uid == self._ctx.client_uid and isinstance(serve, int) and 
Task(self._ctx.serve_clients, host, serve).value() == 0): self._ctx.client_uid_time = time.time() self.send_response(200) self.send_header('Content-Type', 'application/json; charset=utf-8') self.end_headers() self.wfile.write(json.dumps(0).encode()) else: self.send_error(400) return return elif client_request == 'poll_interval': uid = None interval = None for item in form.list: if item.name == 'interval': try: interval = int(item.value) except Exception: if interval is not None: logger.warning( '%s: invalid poll interval "%s" ignored', self._ctx.client_uid, item.value) self.send_error(400) return elif item.name == 'uid': uid = item.value.strip() if (uid == self._ctx.client_uid and self._ctx.set_poll_interval(interval) == 0): self._ctx.client_uid_time = time.time() self.send_response(200) self.send_header('Content-Type', 'application/json; charset=utf-8') self.end_headers() self.wfile.write(json.dumps(0).encode()) else: self.send_error(400) return logger.debug('Bad POST request from %s: %s', self.client_address[0], client_request) self.send_error(400) return def __init__(self, DocumentRoot, secret='', http_host='localhost', poll_interval=60, ping_interval=600, ip_addrs=[], ipv4_udp_multicast=False, certfile=None, keyfile=None): http_port = dispy.config.HTTPServerPort self.node_port = eval(dispy.config.NodePort) self.info_port = eval(dispy.config.ClientPort) self.lock = threading.Lock() self.client_uid = None self.client_uid_time = 0 self.nodes = {} self.updates = {} if poll_interval < 1: logger.warning( 'invalid poll_interval value %s; it must be at least 1', poll_interval) poll_interval = 1 self.poll_interval = poll_interval self.ping_interval = ping_interval self.secret = secret self.keyfile = keyfile self.certfile = certfile self.ipv4_udp_multicast = bool(ipv4_udp_multicast) self.addrinfos = {} if not ip_addrs: ip_addrs = [None] for i in range(len(ip_addrs)): ip_addr = ip_addrs[i] addrinfo = dispy.host_addrinfo( host=ip_addr, 
ipv4_multicast=self.ipv4_udp_multicast) if not addrinfo: logger.warning('Ignoring invalid ip_addr %s', ip_addr) continue self.addrinfos[addrinfo.ip] = addrinfo if not self.addrinfos: raise Exception('No valid IP address found') self.sign = hashlib.sha1(os.urandom(20)) for ip_addr in self.addrinfos: self.sign.update(ip_addr.encode()) self.sign = self.sign.hexdigest() self.auth = dispy.auth_code(self.secret, self.sign) self.tcp_tasks = [] self.udp_tasks = [] udp_addrinfos = {} for addrinfo in self.addrinfos.values(): self.tcp_tasks.append(Task(self.tcp_server, addrinfo)) udp_addrinfos[addrinfo.bind_addr] = addrinfo for bind_addr, addrinfo in udp_addrinfos.items(): self.udp_tasks.append(Task(self.udp_server, addrinfo)) self._server = HTTPServer( (http_host, http_port), lambda *args: self.__class__._HTTPRequestHandler( self, DocumentRoot, *args)) if certfile: self._server.socket = ssl.wrap_socket(self._server.socket, keyfile=keyfile, certfile=certfile, server_side=True) self.timer = Task(self.timer_proc) self._httpd_thread = threading.Thread( target=self._server.serve_forever) self._httpd_thread.daemon = True self._httpd_thread.start() self.client_host = self._server.socket.getsockname()[0] logger.info('Started HTTP%s server at %s:%s', 's' if certfile else '', self.client_host, self._server.socket.getsockname()[1]) def tcp_server(self, addrinfo, task=None): task.set_daemon() sock = AsyncSocket(socket.socket(addrinfo.family, socket.SOCK_STREAM), keyfile=self.keyfile, certfile=self.certfile) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) try: sock.bind((addrinfo.ip, self.info_port)) except Exception: logger.warning('Could not bind TCP server to %s:%s', addrinfo.ip, self.info_port) raise StopIteration logger.info('dispyadmin TCP server at %s:%s', addrinfo.ip, self.info_port) sock.listen(16) while 1: try: conn, addr = yield sock.accept() except ssl.SSLError as err: logger.debug('SSL connection failed: %s', str(err)) continue except GeneratorExit: break except 
Exception: logger.debug(traceback.format_exc()) continue Task(self.tcp_req, conn, addr) sock.close() def udp_server(self, addrinfo, task=None): task.set_daemon() udp_sock = AsyncSocket( socket.socket(addrinfo.family, socket.SOCK_DGRAM)) udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if hasattr(socket, 'SO_REUSEPORT'): try: udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) except Exception: pass udp_sock.bind((addrinfo.bind_addr, self.info_port)) if addrinfo.family == socket.AF_INET: if self.ipv4_udp_multicast: mreq = socket.inet_aton(addrinfo.broadcast) + socket.inet_aton( addrinfo.ip) udp_sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq) else: # addrinfo.family == socket.AF_INET6: mreq = socket.inet_pton(addrinfo.family, addrinfo.broadcast) mreq += struct.pack('@I', addrinfo.ifn) udp_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, mreq) try: udp_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) except Exception: pass while 1: msg, addr = yield udp_sock.recvfrom(1000) if msg.startswith(b'PING:'): try: info = deserialize(msg[len(b'PING:'):]) if info['version'] != _dispy_version: logger.warning('Ignoring %s due to version mismatch', addr[0]) continue assert info['port'] > 0 assert info['ip_addr'] except Exception: logger.debug('Ignoring node %s', addr[0]) continue node = self.nodes.get(info['ip_addr'], None) if node: if node._priv.sign == info['sign']: Task(self.update_node_info, node) else: node._priv.sign = info['sign'] node._priv.auth = None Task(self.get_node_info, node) else: info['family'] = addrinfo.family Task(self.add_node, info) elif msg.startswith(b'TERMINATED:'): try: info = deserialize(msg[len(b'TERMINATED:'):]) assert info['ip_addr'] except Exception: logger.debug('Ignoring node %s', addr[0]) continue node = self.nodes.get(info['ip_addr'], None) if node and node._priv.sign == info['sign']: with self.lock: self.nodes.pop(info['ip_addr'], None) def tcp_req(self, conn, addr, task=None): 
conn.settimeout(MsgTimeout) msg = yield conn.recv_msg() if msg.startswith(b'NODE_INFO:'): try: info = deserialize(msg[len(b'NODE_INFO:'):]) dispy.logger.info('info: %s', info) node = info.get('ip_addr', None) if info.get('version', None) != _dispy_version: dispy.logger.warning( 'Ignoring node at %s due to version mismatch (%s != %s)', info.get('ip_addr', None), info.get('version', None), _dispy_version) raise StopIteration assert info['sign'] info['family'] = conn.family except Exception: # dispy.logger.debug(traceback.format_exc()) raise StopIteration finally: conn.close() yield self.add_node(info) raise StopIteration def set_node_info(self, node, info): node.scheduler_ip = info['scheduler_ip'] node.clients_done = info['clients_done'] node.jobs_done = info['jobs_done'] node.cpu_time = info['cpu_time'] node.busy = info['busy'] node.serve = info['serve'] if 'service_start' in info: node.service_start = info['service_start'] node.service_stop = info['service_stop'] node.service_end = info['service_end'] node.avail_info = info['avail_info'] node.update_time = time.time() with self.lock: self.updates[node.ip_addr] = node def get_node_info(self, node, task=None): auth = node._priv.auth if not auth: auth = dispy.auth_code(self.secret, node._priv.sign) sock = AsyncSocket(socket.socket(node._priv.sock_family, socket.SOCK_STREAM), keyfile=self.keyfile, certfile=self.certfile) sock.settimeout(MsgTimeout) try: yield sock.connect((node.ip_addr, node._priv.port)) yield sock.sendall(auth) yield sock.send_msg(b'NODE_INFO:' + serialize({'sign': self.sign})) info = yield sock.recv_msg() except Exception: dispy.logger.debug('Could not get node information from %s:%s', node.ip_addr, node._priv.port) # dispy.logger.debug(traceback.format_exc()) raise StopIteration(-1) finally: sock.close() try: info = deserialize(info) node.name = info['name'] node.cpus = info['cpus'] node.max_cpus = info['max_cpus'] except Exception: sign = info.decode() if node._priv.sign == sign: node.update_time = 
time.time() raise StopIteration(0) else: node._priv.sign = sign ret = yield self.get_node_info(node, task=task) raise StopIteration(ret) else: node._priv.auth = auth self.set_node_info(node, info) raise StopIteration(0) def add_node(self, info, task=None): sign = info.get('sign', '') family = info.get('family', None) if not family: for addr in socket.getaddrinfo(info['ip_addr'], info['port'], type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP): family = addr[0] break node = DispyAdminServer._NodeInfo(info['ip_addr'], info['port'], family, sign) ret = yield self.get_node_info(node, task=task) if ret == 0: with self.lock: self.nodes[node.ip_addr] = node self.updates[node.ip_addr] = node def set_secret(self, secret, task=None): with self.lock: self.secret = secret for node in self.nodes.values(): if not node._priv.auth: Task(self.get_node_info, node) self.timer.resume() def set_cpus(self, host, cpus, task=None): node = self.nodes.get(host, None) if not node or not node._priv.auth: raise StopIteration(-1) sock = AsyncSocket(socket.socket(node._priv.sock_family, socket.SOCK_STREAM), keyfile=self.keyfile, certfile=self.certfile) sock.settimeout(MsgTimeout) try: yield sock.connect((node.ip_addr, node._priv.port)) yield sock.sendall(node._priv.auth) yield sock.send_msg('SET_CPUS:'.encode() + serialize({'cpus': cpus})) resp = yield sock.recv_msg() info = deserialize(resp) node.cpus = info['cpus'] except Exception: dispy.logger.debug('Setting cpus of %s to %s failed', host, cpus) raise StopIteration(-1) else: raise StopIteration(0) finally: sock.close() def service_time(self, host, control, time, task=None): node = self.nodes.get(dispy._node_ipaddr(host), None) if not node or not node._priv.auth: raise StopIteration(-1) sock = AsyncSocket(socket.socket(node._priv.sock_family, socket.SOCK_STREAM), keyfile=self.keyfile, certfile=self.certfile) sock.settimeout(MsgTimeout) try: yield sock.connect((node.ip_addr, node._priv.port)) yield sock.sendall(node._priv.auth) yield 
sock.send_msg('SERVICE_TIME:'.encode() + serialize({ 'control': control, 'time': time })) resp = yield sock.recv_msg() info = deserialize(resp) node.service_start = info['service_start'] node.service_stop = info['service_stop'] node.service_end = info['service_end'] resp = 0 except Exception: resp = -1 sock.close() if resp: dispy.logger.debug('Setting service %s time of %s to %s failed', control, host, time) raise StopIteration(resp) def serve_clients(self, host, serve, task=None): node = self.nodes.get(dispy._node_ipaddr(host), None) if not node or not node._priv.auth: raise StopIteration(-1) sock = AsyncSocket(socket.socket(node._priv.sock_family, socket.SOCK_STREAM), keyfile=self.keyfile, certfile=self.certfile) sock.settimeout(MsgTimeout) try: yield sock.connect((node.ip_addr, node._priv.port)) yield sock.sendall(node._priv.auth) yield sock.send_msg('SERVE_CLIENTS:'.encode() + serialize({'serve': serve})) resp = yield sock.recv_msg() info = deserialize(resp) node.serve = info['serve'] resp = 0 except Exception: dispy.logger.debug('Setting serve clients %s to %s failed', host, serve) resp = -1 finally: sock.close() raise StopIteration(resp) def update_node_info(self, node, task=None): sock = AsyncSocket(socket.socket(node._priv.sock_family, socket.SOCK_STREAM), keyfile=self.keyfile, certfile=self.certfile) sock.settimeout(MsgTimeout) try: yield sock.connect((node.ip_addr, node._priv.port)) yield sock.sendall(node._priv.auth) yield sock.send_msg(b'NODE_STATUS:') info = yield sock.recv_msg() info = deserialize(info) if isinstance(info, dict): self.set_node_info(node, info) except Exception: # TODO: remove node if update is long ago? 
pass finally: sock.close() def set_poll_interval(self, interval): if not isinstance(interval, int): if interval is None: self.timer.resume() return 0 else: return -1 if not interval: self.timer.resume() return 0 elif interval >= 5: self.poll_interval = interval self.timer.resume(update=interval) return 0 else: return -1 def set_uid(self, client_host, poll_interval, uid=None): now = time.time() try: poll_interval = int(poll_interval) assert poll_interval >= 5 except Exception: return None if ((uid == self.client_uid) or (not self.client_uid) or (self.client_host == client_host) or ((now - self.client_uid_time) > min(5 * self.poll_interval, 600))): if not uid: uid = hashlib.sha1(os.urandom(20)).hexdigest() self.client_uid = uid self.client_uid_time = now self.client_host = client_host if self.poll_interval != poll_interval: self.poll_interval = poll_interval self.timer.resume(update=poll_interval) return uid else: logger.warning( 'Ignoring client at %s; currently controlled by client at %s', client_host, self.client_host) return None def discover_nodes(self, task=None): addrinfos = list(self.addrinfos.values()) for addrinfo in addrinfos: info_msg = { 'ip_addr': addrinfo.ip, 'port': self.info_port, 'sign': self.sign, 'version': _dispy_version } bc_sock = AsyncSocket( socket.socket(addrinfo.family, socket.SOCK_DGRAM)) bc_sock.settimeout(MsgTimeout) ttl_bin = struct.pack('@i', 1) if addrinfo.family == socket.AF_INET: if self.ipv4_udp_multicast: bc_sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl_bin) else: bc_sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) else: # addrinfo.family == socket.AF_INET6 bc_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, ttl_bin) bc_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, addrinfo.ifn) bc_sock.bind((addrinfo.ip, 0)) try: yield bc_sock.sendto(b'NODE_INFO:' + serialize(info_msg), (addrinfo.broadcast, self.node_port)) except Exception: pass bc_sock.close() def timer_proc(self, 
task=None): task.set_daemon() Task(self.discover_nodes) last_ping = time.time() interval = self.poll_interval while 1: now = time.time() if (now - last_ping) >= self.ping_interval: Task(self.discover_nodes) last_ping = now if (now - self.client_uid_time) > (5 * self.ping_interval): interval *= 2 with self.lock: nodes = list(self.nodes.values()) # TODO: it may be better to have nodes send updates periodically for node in nodes: if node._priv.auth: Task(self.update_node_info, node) update = yield task.sleep(interval) if update: interval = update def shutdown(self, wait=True): """This method should be called by user program to close the http server. """ if wait: logger.info( 'HTTP server waiting for %s seconds for client updates before quitting' % self.poll_interval) time.sleep(self.poll_interval) self._server.shutdown() self._server.server_close()
def udp_server(self, addrinfo, task=None):
    """Receive UDP node announcements on (addrinfo.bind_addr, self.info_port).

    Joins the IPv4 multicast group (when ``self.ipv4_udp_multicast``) or the
    IPv6 group given by ``addrinfo.broadcast``, then handles two messages:

    * ``PING:`` -- a node announcement. A known node whose sign is unchanged
      is refreshed with ``update_node_info``. If its sign changed, its auth is
      reset and ``get_node_info`` is run again. Unknown nodes are passed to
      ``add_node``.
    * ``TERMINATED:`` -- the node is removed from ``self.nodes`` if the sign
      matches the stored one.

    Malformed datagrams (undeserializable, or missing 'ip_addr' / 'sign' /
    valid 'port') are logged at debug level and ignored, so that one bad
    packet cannot terminate this task. The socket is closed when the task
    ends.
    """
    task.set_daemon()
    udp_sock = AsyncSocket(socket.socket(addrinfo.family, socket.SOCK_DGRAM))
    udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    if hasattr(socket, 'SO_REUSEPORT'):
        try:
            udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        except Exception:
            pass
    udp_sock.bind((addrinfo.bind_addr, self.info_port))
    if addrinfo.family == socket.AF_INET:
        if self.ipv4_udp_multicast:
            mreq = socket.inet_aton(addrinfo.broadcast) + socket.inet_aton(addrinfo.ip)
            udp_sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
    else:  # addrinfo.family == socket.AF_INET6:
        mreq = socket.inet_pton(addrinfo.family, addrinfo.broadcast)
        mreq += struct.pack('@I', addrinfo.ifn)
        udp_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, mreq)
        try:
            udp_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
        except Exception:
            pass

    try:
        while 1:
            msg, addr = yield udp_sock.recvfrom(1000)
            if msg.startswith(b'PING:'):
                try:
                    info = deserialize(msg[len(b'PING:'):])
                    if info['version'] != _dispy_version:
                        logger.warning('Ignoring %s due to version mismatch', addr[0])
                        continue
                    # explicit checks instead of 'assert' (which -O strips)
                    if not (info['port'] > 0 and info['ip_addr']):
                        raise ValueError('invalid port / ip_addr')
                    # read 'sign' here so a missing key is rejected, not fatal
                    sign = info['sign']
                except Exception:
                    logger.debug('Ignoring node %s', addr[0])
                    continue
                node = self.nodes.get(info['ip_addr'], None)
                if node:
                    if node._priv.sign == sign:
                        Task(self.update_node_info, node)
                    else:
                        node._priv.sign = sign
                        node._priv.auth = None
                        Task(self.get_node_info, node)
                else:
                    info['family'] = addrinfo.family
                    Task(self.add_node, info)
            elif msg.startswith(b'TERMINATED:'):
                try:
                    info = deserialize(msg[len(b'TERMINATED:'):])
                    if not info['ip_addr']:
                        raise ValueError('invalid ip_addr')
                    sign = info['sign']
                except Exception:
                    logger.debug('Ignoring node %s', addr[0])
                    continue
                node = self.nodes.get(info['ip_addr'], None)
                if node and node._priv.sign == sign:
                    with self.lock:
                        self.nodes.pop(info['ip_addr'], None)
    finally:
        # runs when the task is terminated (GeneratorExit at the yield) or on error
        udp_sock.close()
class DispyAdminServer(object):
    """HTTP(S) server for the dispy admin web interface.

    Discovers dispy nodes from UDP announcements and TCP 'NODE_INFO:'
    messages on ``info_port``, and from periodic discovery broadcasts sent to
    ``node_port``. It keeps node status in ``self.nodes``, records changes in
    ``self.updates``, serves the admin pages from ``DocumentRoot``, and
    answers a small JSON POST API. API requests must carry the uid handed out
    by the 'get_uid' request.
    """

    class _NodePriv(object):
        """Connection details of a node that are never sent to web clients."""

        def __init__(self, port, sock_family, sign):
            self.port = port
            self.sock_family = sock_family
            self.sign = sign
            self.auth = None

    class _NodeInfo(object):
        """Status of a node; its public attributes are sent to web clients as JSON."""

        # dotted IPv4 or (lower-case) IPv6 literal; such hosts need not be resolved
        ip_re = re.compile(r'^((\d+\.\d+\.\d+\.\d+)|([0-9a-f:]+))$')

        def __init__(self, ip_addr, port, sock_family, sign):
            self.ip_addr = ip_addr
            self.name = ''
            self.max_cpus = 0
            self.cpus = 0
            self.scheduler_ip = None
            self.clients_done = 0
            self.jobs_done = 0
            self.cpu_time = 0
            self.busy = 0
            self.avail_info = None
            self.service_start = 0
            self.service_stop = 0
            self.service_end = 0
            self.serve = None
            self.update_time = 0
            self._priv = DispyAdminServer._NodePriv(port, sock_family, sign)

    class _HTTPRequestHandler(BaseHTTPRequestHandler):
        """Request handler bound to a DispyAdminServer instance (``ctx``)."""

        # files served by do_GET, keyed by extension; anything else gets 404
        _content_types = {'.html': 'text/html', '.js': 'text/javascript',
                          '.css': 'text/css', '.ico': 'image/x-icon'}

        def __init__(self, ctx, DocumentRoot, *args):
            self._ctx = ctx
            self.DocumentRoot = DocumentRoot
            BaseHTTPRequestHandler.__init__(self, *args)

        @staticmethod
        def _node_to_dict(node):
            """Return a JSON-serializable dict of `node`'s public attributes."""
            info = dict(node.__dict__)
            info.pop('_priv', None)
            if info['avail_info']:
                info['avail_info'] = info['avail_info'].__dict__
            return info

        @staticmethod
        def json_encode_nodes(arg):
            """Return JSON-serializable dicts for the nodes stored as values of `arg`."""
            return [DispyAdminServer._HTTPRequestHandler._node_to_dict(node)
                    for node in dict_iter(arg, 'values')]

        def log_message(self, fmt, *args):
            # suppress BaseHTTPRequestHandler's default per-request logging
            return

        def _send_json(self, obj):
            """Send `obj` as a 200 'application/json' response."""
            self.send_response(200)
            self.send_header('Content-Type', 'application/json; charset=utf-8')
            self.end_headers()
            self.wfile.write(json.dumps(obj).encode())

        def _resolve_path(self, url_path):
            """Map request path `url_path` to a file path under DocumentRoot.

            '' and 'index.html' map to 'admin.html'. Returns None when the
            resolved path lies outside DocumentRoot (e.g. through '..'
            segments or symlinks), so clients cannot read arbitrary files.
            """
            rel_path = urlparse(url_path).path.lstrip('/')
            if rel_path == '' or rel_path == 'index.html':
                rel_path = 'admin.html'
            root = os.path.realpath(self.DocumentRoot)
            full_path = os.path.realpath(os.path.join(root, rel_path))
            root_prefix = root if root.endswith(os.sep) else root + os.sep
            if not full_path.startswith(root_prefix):
                return None
            return full_path

        def do_GET(self):
            """Serve a static file from DocumentRoot.

            admin.html and admin_node.html are %-formatted with the current
            POLL_INTERVAL and NODE_PORT before sending. Paths outside
            DocumentRoot, unsupported file types and unreadable files get 404.
            """
            path = self._resolve_path(self.path)
            content_type = None
            if path:
                content_type = self._content_types.get(os.path.splitext(path)[1])
            if not content_type:
                logger.warning('HTTP client %s: Could not read/send "%s"',
                               self.client_address[0], path or self.path)
                self.send_error(404)
                return
            try:
                # read bytes so binary files (e.g. .ico) are sent unchanged
                with open(path, 'rb') as fd:
                    data = fd.read()
                if path.endswith('admin.html') or path.endswith('admin_node.html'):
                    data = data.decode('utf-8') % {
                        'POLL_INTERVAL': str(self._ctx.poll_interval),
                        'NODE_PORT': str(self._ctx.node_port)}
                    data = data.encode('utf-8')
                self.send_response(200)
                self.send_header('Content-Type', content_type)
                if content_type == 'text/css' or content_type == 'text/javascript':
                    self.send_header('Cache-Control', 'private, max-age=30')
                self.end_headers()
                self.wfile.write(data)
            except Exception:
                logger.warning('HTTP client %s: Could not read/send "%s"',
                               self.client_address[0], path)
                logger.debug(traceback.format_exc())
                self.send_error(404)

        def do_POST(self):
            """Handle an API request; the request name is the URL path without '/'."""
            try:
                form = cgi.FieldStorage(fp=self.rfile, headers=self.headers,
                                        environ={'REQUEST_METHOD': 'POST'})
                client_request = self.path[1:]
                # form.list is None when the body is not form data
                fields = form.list or []
            except Exception:
                logger.debug('Ignoring invalid POST request from %s', self.client_address[0])
                self.send_error(400)
                return

            if client_request == 'update':
                uid = None
                for item in fields:
                    if item.name == 'uid':
                        uid = item.value.strip()
                        break
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                self._ctx.client_uid_time = time.time()
                with self._ctx.lock:
                    nodes = self.__class__.json_encode_nodes(self._ctx.updates)
                    self._ctx.updates.clear()
                self._send_json(nodes)
                return

            elif client_request == 'node_info':
                ip_addr = None
                uid = None
                for item in fields:
                    if item.name == 'host':
                        # if it looks like IP address, skip resolving
                        if re.match(DispyAdminServer._NodeInfo.ip_re, item.value):
                            ip_addr = item.value
                        else:
                            ip_addr = dispy._node_ipaddr(item.value)
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                self._ctx.client_uid_time = time.time()
                node = self._ctx.nodes.get(ip_addr, None)
                self._send_json(self._node_to_dict(node) if node else {})
                return

            elif client_request == 'status':
                uid = None
                for item in fields:
                    if item.name == 'uid':
                        uid = item.value.strip()
                        break
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                self._ctx.client_uid_time = time.time()
                with self._ctx.lock:
                    nodes = self.__class__.json_encode_nodes(self._ctx.nodes)
                self._send_json(nodes)
                return

            elif client_request == 'get_uid':
                uid = None
                for item in fields:
                    if item.name == 'uid':
                        uid = item.value.strip()
                        break
                now = time.time()
                # TODO: only allow from http server?
                if (self._ctx.client_uid and uid != self._ctx.client_uid and
                        ((now - self._ctx.client_uid_time) < 3600)):
                    self.send_error(400, 'invalid uid')
                    return
                if not uid:
                    uid = hashlib.sha1(os.urandom(20)).hexdigest()
                self._ctx.client_uid = uid
                self._ctx.client_uid_time = time.time()
                self._send_json(uid)
                return

            elif client_request == 'set_secret':
                secret = None
                uid = None
                for item in fields:
                    if item.name == 'secret':
                        secret = item.value.strip()
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if secret and uid == self._ctx.client_uid:
                    self._ctx.client_uid_time = time.time()
                    self._ctx.set_secret(secret)
                    self._send_json(0)
                else:
                    self.send_error(400)
                return

            elif client_request == 'add_node':
                host = ''
                port = None
                uid = None
                for item in fields:
                    if item.name == 'host':
                        host = item.value
                    elif item.name == 'port':
                        try:
                            port = int(item.value)
                        except Exception:
                            port = None
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                self._ctx.client_uid_time = time.time()
                ip_addr = dispy._node_ipaddr(host) if (host and port) else None
                if ip_addr:
                    info = {'ip_addr': ip_addr, 'port': port}
                    Task(self._ctx.add_node, info)
                    self._send_json(0)
                else:
                    # every failure path must still produce a response
                    self.send_error(400)
                return

            elif client_request == 'service_time':
                hosts = []
                svc_time = None
                control = None
                uid = None
                for item in fields:
                    if item.name == 'hosts':
                        try:
                            hosts = [str(host) for host in json.loads(item.value)]
                        except Exception:
                            self.send_error(400)
                            return
                    elif item.name == 'control':
                        control = item.value
                    elif item.name == 'time':
                        svc_time = item.value
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                self._ctx.client_uid_time = time.time()
                for host in hosts:
                    Task(self._ctx.service_time, host, control, svc_time)
                self._send_json(0)
                return

            elif client_request == 'set_cpus':
                host = None
                cpus = None
                uid = None
                for item in fields:
                    if item.name == 'host':
                        host = item.value.strip()
                    elif item.name == 'cpus':
                        try:
                            cpus = int(item.value.strip())
                        except Exception:
                            pass
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if (host and cpus and uid == self._ctx.client_uid and
                        Task(self._ctx.set_cpus, host, cpus).value() == 0):
                    self._ctx.client_uid_time = time.time()
                    self._send_json(0)
                else:
                    self.send_error(400)
                return

            elif client_request == 'serve_clients':
                host = ''
                serve = None
                uid = None
                for item in fields:
                    if item.name == 'host':
                        host = item.value
                    elif item.name == 'serve':
                        serve = item.value
                        try:
                            serve = int(serve)
                        except Exception:
                            pass
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if (uid == self._ctx.client_uid and isinstance(serve, int) and
                        Task(self._ctx.serve_clients, host, serve).value() == 0):
                    self._ctx.client_uid_time = time.time()
                    self._send_json(0)
                else:
                    self.send_error(400)
                return

            elif client_request == 'poll_interval':
                uid = None
                interval = None
                for item in fields:
                    if item.name == 'interval':
                        try:
                            interval = int(item.value)
                        except Exception:
                            # an empty value leaves interval None, which makes
                            # set_poll_interval only resume the timer; any other
                            # non-integer value is rejected
                            if (item.value or '').strip():
                                logger.warning('%s: invalid poll interval "%s" ignored',
                                               self._ctx.client_uid, item.value)
                                self.send_error(400)
                                return
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if (uid == self._ctx.client_uid and
                        self._ctx.set_poll_interval(interval) == 0):
                    self._ctx.client_uid_time = time.time()
                    self._send_json(0)
                else:
                    self.send_error(400)
                return

            logger.debug('Bad POST request from %s: %s', self.client_address[0], client_request)
            self.send_error(400)
            return

    def __init__(self, DocumentRoot, secret='', http_host='localhost', http_port=8181,
                 info_port=51347, node_port=51348, poll_interval=60, ping_interval=600,
                 ip_addrs=None, ipv4_udp_multicast=False, certfile=None, keyfile=None):
        """Start the discovery tasks, the timer task and the HTTP(S) server thread.

        DocumentRoot: directory with the admin web pages.
        secret: shared secret used to compute auth codes for nodes.
        http_host, http_port: address the HTTP server binds to.
        info_port: port for node announcements (TCP and UDP servers).
        node_port: port discovery messages are sent to.
        poll_interval: seconds between node status updates (raised to 1 if smaller).
        ping_interval: minimum seconds between discovery broadcasts.
        ip_addrs: local host names / IP addresses to use; None or empty means
            ``dispy.host_addrinfo(host=None)``.
        ipv4_udp_multicast: use IPv4 multicast instead of broadcast.
        certfile, keyfile: when certfile is given the HTTP server uses SSL; both
            are also passed to node connections.

        Raises Exception if none of `ip_addrs` yields a usable address.
        """
        self.lock = threading.Lock()
        self.client_uid = None
        self.client_uid_time = 0
        self.nodes = {}
        self.updates = {}
        if poll_interval < 1:
            logger.warning('invalid poll_interval value %s; it must be at least 1', poll_interval)
            poll_interval = 1
        self.info_port = info_port
        self.poll_interval = poll_interval
        self.ping_interval = ping_interval
        self.secret = secret
        self.node_port = node_port
        self.keyfile = keyfile
        self.certfile = certfile
        self.ipv4_udp_multicast = bool(ipv4_udp_multicast)
        self.addrinfos = {}
        for ip_addr in (ip_addrs or [None]):
            addrinfo = dispy.host_addrinfo(host=ip_addr, ipv4_multicast=self.ipv4_udp_multicast)
            if not addrinfo:
                logger.warning('Ignoring invalid ip_addr %s', ip_addr)
                continue
            self.addrinfos[addrinfo.ip] = addrinfo
        if not self.addrinfos:
            raise Exception('No valid IP address found')
        self.http_port = http_port
        self.sign = hashlib.sha1(os.urandom(20))
        for ip_addr in self.addrinfos:
            self.sign.update(ip_addr.encode())
        self.sign = self.sign.hexdigest()
        self.auth = dispy.auth_code(self.secret, self.sign)

        self.tcp_tasks = []
        self.udp_tasks = []
        # one UDP server per distinct bind address
        udp_addrinfos = {}
        for addrinfo in self.addrinfos.values():
            self.tcp_tasks.append(Task(self.tcp_server, addrinfo))
            udp_addrinfos[addrinfo.bind_addr] = addrinfo
        for addrinfo in udp_addrinfos.values():
            self.udp_tasks.append(Task(self.udp_server, addrinfo))

        self._server = HTTPServer(
            (http_host, http_port),
            lambda *args: self.__class__._HTTPRequestHandler(self, DocumentRoot, *args))
        if certfile:
            # NOTE(review): ssl.wrap_socket is deprecated since Python 3.7 and
            # removed in 3.12; SSLContext.wrap_socket is the replacement.
            self._server.socket = ssl.wrap_socket(self._server.socket, keyfile=keyfile,
                                                  certfile=certfile, server_side=True)
        self.timer = Task(self.timer_proc)
        self._httpd_thread = threading.Thread(target=self._server.serve_forever)
        self._httpd_thread.daemon = True
        self._httpd_thread.start()
        logger.info('Started HTTP%s server at %s', 's' if certfile else '',
                    ':'.join(map(str, self._server.socket.getsockname())))

    def tcp_server(self, addrinfo, task=None):
        """Accept TCP connections on (addrinfo.ip, info_port); each goes to tcp_req."""
        task.set_daemon()
        sock = AsyncSocket(socket.socket(addrinfo.family, socket.SOCK_STREAM),
                           keyfile=self.keyfile, certfile=self.certfile)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            sock.bind((addrinfo.ip, self.info_port))
        except Exception:
            logger.warning('Could not bind TCP server to %s:%s', addrinfo.ip, self.info_port)
            sock.close()
            raise StopIteration
        logger.debug('dispyadmin TCP server at %s:%s', addrinfo.ip, self.info_port)
        sock.listen(16)

        while 1:
            try:
                conn, addr = yield sock.accept()
            except ssl.SSLError as err:
                logger.debug('SSL connection failed: %s', str(err))
                continue
            except GeneratorExit:
                break
            except Exception:
                logger.debug(traceback.format_exc())
                continue
            Task(self.tcp_req, conn, addr)
        sock.close()

    def udp_server(self, addrinfo, task=None):
        """Receive UDP node announcements on (addrinfo.bind_addr, info_port).

        'PING:' refreshes, re-authenticates or adds the announcing node.
        'TERMINATED:' removes it from self.nodes if its sign matches.
        Malformed datagrams are logged and ignored so they cannot end this
        task. The socket is closed when the task ends.
        """
        task.set_daemon()
        udp_sock = AsyncSocket(socket.socket(addrinfo.family, socket.SOCK_DGRAM))
        udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if hasattr(socket, 'SO_REUSEPORT'):
            try:
                udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            except Exception:
                pass
        udp_sock.bind((addrinfo.bind_addr, self.info_port))
        if addrinfo.family == socket.AF_INET:
            if self.ipv4_udp_multicast:
                mreq = socket.inet_aton(addrinfo.broadcast) + socket.inet_aton(addrinfo.ip)
                udp_sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
        else:  # addrinfo.family == socket.AF_INET6:
            mreq = socket.inet_pton(addrinfo.family, addrinfo.broadcast)
            mreq += struct.pack('@I', addrinfo.ifn)
            udp_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, mreq)
            try:
                udp_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
            except Exception:
                pass

        try:
            while 1:
                msg, addr = yield udp_sock.recvfrom(1000)
                # bytes prefixes work with both py2 str and py3 bytes messages
                if msg.startswith(b'PING:'):
                    try:
                        info = deserialize(msg[len(b'PING:'):])
                        if info['version'] != _dispy_version:
                            logger.warning('Ignoring %s due to version mismatch', addr[0])
                            continue
                        if not (info['port'] > 0 and info['ip_addr']):
                            raise ValueError('invalid port / ip_addr')
                        sign = info['sign']
                    except Exception:
                        logger.debug('Ignoring node %s', addr[0])
                        continue
                    node = self.nodes.get(info['ip_addr'], None)
                    if node:
                        if node._priv.sign == sign:
                            Task(self.update_node_info, node)
                        else:
                            node._priv.sign = sign
                            node._priv.auth = None
                            Task(self.get_node_info, node)
                    else:
                        info['family'] = addrinfo.family
                        Task(self.add_node, info)
                elif msg.startswith(b'TERMINATED:'):
                    try:
                        info = deserialize(msg[len(b'TERMINATED:'):])
                        if not info['ip_addr']:
                            raise ValueError('invalid ip_addr')
                        sign = info['sign']
                    except Exception:
                        logger.debug('Ignoring node %s', addr[0])
                        continue
                    node = self.nodes.get(info['ip_addr'], None)
                    if node and node._priv.sign == sign:
                        with self.lock:
                            self.nodes.pop(info['ip_addr'], None)
        finally:
            udp_sock.close()

    def tcp_req(self, conn, addr, task=None):
        """Handle one TCP connection.

        Only 'NODE_INFO:' announcements are understood. One with a matching
        dispy version and a non-empty 'sign' is passed to add_node. The
        connection is closed in every case.
        """
        conn.settimeout(MsgTimeout)
        info = None
        try:
            msg = yield conn.recv_msg()
            if msg and msg.startswith(b'NODE_INFO:'):
                info = deserialize(msg[len(b'NODE_INFO:'):])
                dispy.logger.debug('info: %s', info)
                if info.get('version', None) != _dispy_version:
                    dispy.logger.warning(
                        'Ignoring node at %s due to version mismatch (%s != %s)',
                        info.get('ip_addr', None), info.get('version', None), _dispy_version)
                    info = None
                elif not info.get('sign'):
                    info = None
                else:
                    info['family'] = conn.family
        except Exception:
            info = None
        finally:
            conn.close()
        if info:
            yield self.add_node(info)

    def set_node_info(self, node, info):
        """Copy status fields from `info` (a node's reply) into `node` and queue it in updates."""
        node.scheduler_ip = info['scheduler_ip']
        node.clients_done = info['clients_done']
        node.jobs_done = info['jobs_done']
        node.cpu_time = info['cpu_time']
        node.busy = info['busy']
        node.serve = info['serve']
        if 'service_start' in info:
            node.service_start = info['service_start']
            node.service_stop = info['service_stop']
            node.service_end = info['service_end']
        node.avail_info = info['avail_info']
        node.update_time = time.time()
        with self.lock:
            self.updates[node.ip_addr] = node

    def get_node_info(self, node, task=None):
        """Request NODE_INFO from `node`, authenticating with its current sign.

        A deserializable reply updates name/cpus/max_cpus and the status
        fields, and stores the auth code. Any other reply is decoded as the
        node's sign. If that sign is unchanged, only update_time is refreshed.
        Otherwise the sign is stored and the request is retried.
        Returns (via StopIteration) 0 on success, -1 on failure.
        """
        auth = node._priv.auth
        if not auth:
            auth = dispy.auth_code(self.secret, node._priv.sign)
        sock = AsyncSocket(socket.socket(node._priv.sock_family, socket.SOCK_STREAM),
                           keyfile=self.keyfile, certfile=self.certfile)
        sock.settimeout(MsgTimeout)
        try:
            yield sock.connect((node.ip_addr, node._priv.port))
            yield sock.sendall(auth)
            yield sock.send_msg(b'NODE_INFO:' + serialize({'sign': self.sign}))
            info = yield sock.recv_msg()
        except Exception:
            dispy.logger.debug('Could not get node information from %s:%s',
                               node.ip_addr, node._priv.port)
            raise StopIteration(-1)
        finally:
            sock.close()

        try:
            info = deserialize(info)
            node.name = info['name']
            node.cpus = info['cpus']
            node.max_cpus = info['max_cpus']
        except Exception:
            try:
                sign = info.decode()
            except Exception:
                # neither node information nor a sign
                dispy.logger.debug('Invalid node information from %s:%s',
                                   node.ip_addr, node._priv.port)
                raise StopIteration(-1)
            if node._priv.sign == sign:
                node.update_time = time.time()
                raise StopIteration(0)
            else:
                node._priv.sign = sign
                raise StopIteration((yield self.get_node_info(node, task=task)))
        else:
            node._priv.auth = auth
            self.set_node_info(node, info)
            raise StopIteration(0)

    def add_node(self, info, task=None):
        """Create a node from `info` ('ip_addr', 'port', optional 'sign'/'family').

        The node is added to nodes/updates if get_node_info succeeds.
        """
        sign = info.get('sign', '')
        family = info.get('family', None)
        if not family:
            for addr in socket.getaddrinfo(info['ip_addr'], info['port'],
                                           type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP):
                family = addr[0]
                break
        node = DispyAdminServer._NodeInfo(info['ip_addr'], info['port'], family, sign)
        ret = yield self.get_node_info(node, task=task)
        if ret == 0:
            with self.lock:
                self.nodes[node.ip_addr] = node
                self.updates[node.ip_addr] = node

    def set_secret(self, secret, task=None):
        """Change the shared secret, retry unauthenticated nodes and resume the timer."""
        with self.lock:
            self.secret = secret
            for node in self.nodes.values():
                if not node._priv.auth:
                    Task(self.get_node_info, node)
        self.timer.resume()

    def set_cpus(self, host, cpus, task=None):
        """Ask node `host` (looked up as given in self.nodes) to use `cpus` CPUs.

        Returns (via StopIteration) 0 on success, -1 on failure.
        """
        node = self.nodes.get(host, None)
        if not node or not node._priv.auth:
            raise StopIteration(-1)
        sock = AsyncSocket(socket.socket(node._priv.sock_family, socket.SOCK_STREAM),
                           keyfile=self.keyfile, certfile=self.certfile)
        sock.settimeout(MsgTimeout)
        try:
            yield sock.connect((node.ip_addr, node._priv.port))
            yield sock.sendall(node._priv.auth)
            yield sock.send_msg(b'SET_CPUS:' + serialize({'cpus': cpus}))
            resp = yield sock.recv_msg()
            info = deserialize(resp)
            node.cpus = info['cpus']
        except Exception:
            dispy.logger.debug('Setting cpus of %s to %s failed', host, cpus)
            raise StopIteration(-1)
        else:
            raise StopIteration(0)
        finally:
            sock.close()

    def service_time(self, host, control, time, task=None):
        """Send a SERVICE_TIME request (`control`, `time`) to node `host`.

        The parameter `time` shadows the time module inside this method; the
        name is kept for existing callers. Returns (via StopIteration) 0 on
        success, -1 on failure.
        """
        node = self.nodes.get(dispy._node_ipaddr(host), None)
        if not node or not node._priv.auth:
            raise StopIteration(-1)
        sock = AsyncSocket(socket.socket(node._priv.sock_family, socket.SOCK_STREAM),
                           keyfile=self.keyfile, certfile=self.certfile)
        sock.settimeout(MsgTimeout)
        try:
            yield sock.connect((node.ip_addr, node._priv.port))
            yield sock.sendall(node._priv.auth)
            yield sock.send_msg(b'SERVICE_TIME:' + serialize({'control': control, 'time': time}))
            resp = yield sock.recv_msg()
            info = deserialize(resp)
            node.service_start = info['service_start']
            node.service_stop = info['service_stop']
            node.service_end = info['service_end']
            resp = 0
        except Exception:
            resp = -1
        finally:
            sock.close()
        if resp:
            dispy.logger.debug('Setting service %s time of %s to %s failed', control, host, time)
        raise StopIteration(resp)

    def serve_clients(self, host, serve, task=None):
        """Send SERVE_CLIENTS (`serve`) to node `host`; returns 0 on success, -1 on failure."""
        node = self.nodes.get(dispy._node_ipaddr(host), None)
        if not node or not node._priv.auth:
            raise StopIteration(-1)
        sock = AsyncSocket(socket.socket(node._priv.sock_family, socket.SOCK_STREAM),
                           keyfile=self.keyfile, certfile=self.certfile)
        sock.settimeout(MsgTimeout)
        try:
            yield sock.connect((node.ip_addr, node._priv.port))
            yield sock.sendall(node._priv.auth)
            yield sock.send_msg(b'SERVE_CLIENTS:' + serialize({'serve': serve}))
            resp = yield sock.recv_msg()
            info = deserialize(resp)
            node.serve = info['serve']
            resp = 0
        except Exception:
            dispy.logger.debug('Setting serve %s to %s failed', host, serve)
            resp = -1
        finally:
            sock.close()
        raise StopIteration(resp)

    def update_node_info(self, node, task=None):
        """Request NODE_STATUS from an authenticated node and apply a dict reply."""
        sock = AsyncSocket(socket.socket(node._priv.sock_family, socket.SOCK_STREAM),
                           keyfile=self.keyfile, certfile=self.certfile)
        sock.settimeout(MsgTimeout)
        try:
            yield sock.connect((node.ip_addr, node._priv.port))
            yield sock.sendall(node._priv.auth)
            yield sock.send_msg(b'NODE_STATUS:')
            info = yield sock.recv_msg()
            info = deserialize(info)
            if isinstance(info, dict):
                self.set_node_info(node, info)
        except Exception:
            logger.debug('Could not update node at %s:%s', node.ip_addr, node._priv.port)
            # TODO: remove node if update is long ago?
        finally:
            sock.close()

    def set_poll_interval(self, interval):
        """Set the status poll interval (seconds, >= 5) and resume the timer.

        None or 0 only resumes the timer. Returns 0 on success, -1 for an
        invalid value.
        """
        if not isinstance(interval, int):
            if interval is None:
                self.timer.resume()
                return 0
            else:
                return -1
        if not interval:
            self.timer.resume()
            return 0
        elif interval >= 5:
            self.poll_interval = interval
            self.timer.resume()
            return 0
        else:
            return -1

    def timer_proc(self, task=None):
        """Periodic task: poll node status and send discovery broadcasts.

        Each cycle sleeps poll_interval seconds (set_poll_interval and
        set_secret call self.timer.resume(), presumably to cut the sleep
        short), then starts update_node_info for every authenticated node.
        When at least ping_interval seconds have passed since the last
        broadcast, it sends a 'NODE_INFO:' message to addrinfo.broadcast on
        node_port from each local address.
        """
        task.set_daemon()
        last_ping = 0
        addrinfos = list(self.addrinfos.values())
        while 1:
            yield task.sleep(self.poll_interval)
            now = time.time()
            with self.lock:
                nodes = list(self.nodes.values())
            # TODO: it may be better to have nodes send updates periodically
            for node in nodes:
                if node._priv.auth:
                    Task(self.update_node_info, node)

            if (now - last_ping) >= self.ping_interval:
                last_ping = now
                for addrinfo in addrinfos:
                    info_msg = {'ip_addr': addrinfo.ip, 'port': self.info_port,
                                'sign': self.sign, 'version': _dispy_version}
                    bc_sock = AsyncSocket(socket.socket(addrinfo.family, socket.SOCK_DGRAM))
                    try:
                        bc_sock.settimeout(MsgTimeout)
                        ttl_bin = struct.pack('@i', 1)
                        if addrinfo.family == socket.AF_INET:
                            if self.ipv4_udp_multicast:
                                bc_sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL,
                                                   ttl_bin)
                            else:
                                bc_sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
                        else:  # addrinfo.family == socket.AF_INET6
                            bc_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS,
                                               ttl_bin)
                            bc_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF,
                                               addrinfo.ifn)
                        bc_sock.bind((addrinfo.ip, 0))
                        yield bc_sock.sendto(b'NODE_INFO:' + serialize(info_msg),
                                             (addrinfo.broadcast, self.node_port))
                    except Exception:
                        # best effort: one failing interface must not stop the timer
                        logger.debug('Could not send discovery message from %s', addrinfo.ip)
                    finally:
                        bc_sock.close()

    def shutdown(self, wait=True):
        """This method should be called by user program to close the http server.

        With `wait`, sleeps poll_interval seconds first so web clients can
        fetch final updates.
        """
        if wait:
            logger.info(
                'HTTP server waiting for %s seconds for client updates before quitting' %
                self.poll_interval)
            time.sleep(self.poll_interval)
        self._server.shutdown()
        self._server.server_close()
def __init__(self, ip_addrs=None, node_port=51348, relay_port=0, scheduler_nodes=None,
             scheduler_port=51347, ipv4_udp_multicast=False, secret='', certfile=None,
             keyfile=None):
    """Start relay tasks on every usable local interface.

    ip_addrs: local host names / IP addresses to use; None or empty means
        ``dispy.host_addrinfo(host=None)``.
    node_port: node port; also used as relay_port when that is 0.
    relay_port: port stored as ``self.relay_port``.
    scheduler_nodes: scheduler host names / IP addresses. Each one that
        resolves is passed to ``verify_broadcast`` on every UDP bind address.
    scheduler_port: scheduler port included in those messages.
    ipv4_udp_multicast: for IPv4 interfaces, replace the broadcast address
        with ``dispy.IPV4_MULTICAST_GROUP``.
    secret: stored as ``self.secret``.
    certfile, keyfile: optional SSL files, stored as absolute paths.
    """
    addrinfos = []
    for ip_addr in (ip_addrs or [None]):
        addrinfo = dispy.host_addrinfo(host=ip_addr)
        if not addrinfo:
            logger.warning('Ignoring invalid ip_addr %s', ip_addr)
            continue
        addrinfos.append(addrinfo)
    if not addrinfos:
        logger.warning('No valid IP address found; no relay tasks started')

    self.node_port = node_port
    if not relay_port:
        relay_port = node_port
    self.relay_port = relay_port
    self.ipv4_udp_multicast = bool(ipv4_udp_multicast)
    self.ip_addrs = set()
    self.scheduler_ip_addr = None
    self.scheduler_port = scheduler_port
    self.secret = secret
    # absolute paths keep working if the current directory changes later
    self.certfile = os.path.abspath(certfile) if certfile else None
    self.keyfile = os.path.abspath(keyfile) if keyfile else None

    udp_addrinfos = {}
    for addrinfo in addrinfos:
        self.ip_addrs.add(addrinfo.ip)
        if addrinfo.family == socket.AF_INET and self.ipv4_udp_multicast:
            addrinfo.broadcast = dispy.IPV4_MULTICAST_GROUP
        Task(self.relay_tcp_proc, addrinfo)
        # Platform-specific UDP bind address: Windows uses the interface IP;
        # macOS with IPv4 broadcast uses the wildcard address; otherwise the
        # broadcast / multicast address. Presumably this is because platforms
        # differ in which bound sockets receive broadcasts -- TODO confirm.
        if os.name == 'nt':
            bind_addr = addrinfo.ip
        elif sys.platform == 'darwin':
            if addrinfo.family == socket.AF_INET and (not self.ipv4_udp_multicast):
                bind_addr = ''
            else:
                bind_addr = addrinfo.broadcast
        else:
            bind_addr = addrinfo.broadcast
        udp_addrinfos[bind_addr] = addrinfo

    scheduler_ip_addrs = []
    for addr in (scheduler_nodes or []):
        addr = dispy._node_ipaddr(addr)
        if addr:
            scheduler_ip_addrs.append(addr)

    for bind_addr, addrinfo in udp_addrinfos.items():
        Task(self.relay_udp_proc, bind_addr, addrinfo)
        Task(self.sched_udp_proc, bind_addr, addrinfo)
        for addr in scheduler_ip_addrs:
            msg = {'version': __version__, 'ip_addrs': [addr], 'port': self.scheduler_port,
                   'sign': None}
            Task(self.verify_broadcast, addrinfo, msg)
    logger.info('version %s started', dispy._dispy_version)
def do_POST(self):
    """Handle a POST request from the admin web client.

    The request path without its leading '/' selects the action: 'update',
    'node_info', 'status', 'get_uid', 'set_secret', 'add_node',
    'service_time', 'set_cpus', 'serve_clients' or 'poll_interval'.  Form
    fields are parsed with ``cgi.FieldStorage``.  Apart from 'get_uid', each
    action requires the 'uid' field to equal ``self._ctx.client_uid``.
    Successful actions reply 200 with a JSON body; anything else gets 400.
    """
    try:
        form = cgi.FieldStorage(fp=self.rfile, headers=self.headers,
                                environ={'REQUEST_METHOD': 'POST'})
        client_request = self.path[1:]
    except Exception:
        logger.debug('Ignoring invalid POST request from %s', self.client_address[0])
        self.send_error(400)
        return
    # FieldStorage.list is None when the body is not form-encoded; treat that
    # as an empty form rather than crashing while iterating it.
    items = form.list or []

    if client_request == 'update':
        uid = next((item.value.strip() for item in items if item.name == 'uid'), None)
        if not self._check_uid(uid):
            return
        # 'with' guarantees the lock is released even if encoding fails.
        with self._ctx.lock:
            nodes = self.__class__.json_encode_nodes(self._ctx.updates)
            self._ctx.updates.clear()
        self._send_json(nodes)
        return

    elif client_request == 'node_info':
        ip_addr = None
        uid = None
        for item in items:
            if item.name == 'host':
                # if it looks like IP address, skip resolving
                if re.match(DispyAdminServer._NodeInfo.ip_re, item.value):
                    ip_addr = item.value
                else:
                    ip_addr = dispy._node_ipaddr(item.value)
            elif item.name == 'uid':
                uid = item.value.strip()
        if not self._check_uid(uid):
            return
        node = self._ctx.nodes.get(ip_addr, None)
        if node:
            node = dict(node.__dict__)
            node.pop('_priv', None)
            if node.get('avail_info'):
                node['avail_info'] = node['avail_info'].__dict__
        else:
            node = {}
        self._send_json(node)
        return

    elif client_request == 'status':
        uid = next((item.value.strip() for item in items if item.name == 'uid'), None)
        if not self._check_uid(uid):
            return
        with self._ctx.lock:
            nodes = self.__class__.json_encode_nodes(self._ctx.nodes)
        self._send_json(nodes)
        return

    elif client_request == 'get_uid':
        uid = None
        poll_interval = None
        for item in items:
            if item.name == 'uid':
                uid = item.value.strip()
            elif item.name == 'poll_interval':
                try:
                    poll_interval = int(item.value)
                except (TypeError, ValueError):
                    poll_interval = None
                # Explicit check instead of 'assert', which -O strips.
                if poll_interval is None or poll_interval < 5:
                    self.send_error(400, 'invalid poll interval')
                    return
        if poll_interval is None:
            # Missing field previously raised UnboundLocalError.
            self.send_error(400, 'invalid poll interval')
            return
        # TODO: only allow from http server?
        uid = self._ctx.set_uid(self.client_address[0], poll_interval, uid)
        if not uid:
            self.send_error(400, 'invalid uid')
            return
        self._send_json(uid)
        return

    elif client_request == 'set_secret':
        secret = None
        uid = None
        for item in items:
            if item.name == 'secret':
                secret = item.value.strip()
            elif item.name == 'uid':
                uid = item.value.strip()
        if secret and uid == self._ctx.client_uid:
            self._ctx.client_uid_time = time.time()
            self._ctx.set_secret(secret)
            self._send_json(0)
        else:
            self.send_error(400)
        return

    elif client_request == 'add_node':
        host = ''
        port = None
        uid = None
        for item in items:
            if item.name == 'host':
                host = item.value
            elif item.name == 'port':
                try:
                    port = int(item.value)
                except (TypeError, ValueError):
                    port = None
            elif item.name == 'uid':
                uid = item.value.strip()
        if not self._check_uid(uid):
            return
        if host and port:
            ip_addr = dispy._node_ipaddr(host)
            if ip_addr:
                info = {'ip_addr': ip_addr, 'port': port}
                Task(self._ctx.add_node, info)
                self._send_json(0)
                return
        # Missing host/port or unresolvable host: always answer the client.
        self.send_error(400)
        return

    elif client_request == 'service_time':
        hosts = []
        svc_time = None
        control = None
        uid = None
        for item in items:
            if item.name == 'hosts':
                hosts = self._parse_hosts(item.value)
                if hosts is None:
                    self.send_error(400, 'invalid nodes')
                    return
            elif item.name == 'control':
                control = item.value
            elif item.name == 'time':
                svc_time = item.value.strip()
            elif item.name == 'uid':
                uid = item.value.strip()
        if not self._check_uid(uid):
            return
        for host in hosts:
            Task(self._ctx.service_time, host, control, svc_time)
        self._send_json(0)
        return

    elif client_request == 'set_cpus':
        hosts = []
        cpus = None
        uid = None
        for item in items:
            if item.name == 'hosts':
                hosts = self._parse_hosts(item.value)
                if not hosts:
                    self.send_error(400, 'invalid nodes')
                    return
            elif item.name == 'cpus':
                cpus = item.value
                if cpus is not None:
                    try:
                        cpus = int(item.value)
                    except (TypeError, ValueError):
                        self.send_error(400, 'invalid CPUs')
                        return
            elif item.name == 'uid':
                uid = item.value.strip()
        if not self._check_uid(uid):
            return
        for host in hosts:
            Task(self._ctx.set_cpus, host, cpus)
        self._send_json(0)
        return

    elif client_request == 'serve_clients':
        host = ''
        serve = None
        uid = None
        for item in items:
            if item.name == 'host':
                host = item.value
            elif item.name == 'serve':
                serve = item.value
                try:
                    serve = int(serve)
                except (TypeError, ValueError):
                    pass  # non-integer 'serve' is rejected by isinstance below
            elif item.name == 'uid':
                uid = item.value.strip()
        # Task(...).value() blocks until serve_clients finishes.
        if (uid == self._ctx.client_uid and isinstance(serve, int) and
                Task(self._ctx.serve_clients, host, serve).value() == 0):
            self._ctx.client_uid_time = time.time()
            self._send_json(0)
        else:
            self.send_error(400)
        return

    elif client_request == 'poll_interval':
        uid = None
        interval = None
        for item in items:
            if item.name == 'interval':
                try:
                    interval = int(item.value)
                except (TypeError, ValueError):
                    # NOTE(review): interval is still None here unless an
                    # earlier 'interval' field parsed, so an unparsable value
                    # is normally ignored (set_poll_interval(None) only resumes
                    # the timer) rather than rejected.
                    if interval is not None:
                        logger.warning(
                            '%s: invalid poll interval "%s" ignored',
                            self._ctx.client_uid, item.value)
                        self.send_error(400)
                        return
            elif item.name == 'uid':
                uid = item.value.strip()
        if (uid == self._ctx.client_uid and
                self._ctx.set_poll_interval(interval) == 0):
            self._ctx.client_uid_time = time.time()
            self._send_json(0)
        else:
            self.send_error(400)
        return

    logger.debug('Bad POST request from %s: %s', self.client_address[0], client_request)
    self.send_error(400)
    return


def _send_json(self, obj):
    """Send a 200 response whose body is ``obj`` encoded as UTF-8 JSON."""
    self.send_response(200)
    self.send_header('Content-Type', 'application/json; charset=utf-8')
    self.end_headers()
    self.wfile.write(json.dumps(obj).encode())


def _check_uid(self, uid):
    """Validate ``uid`` against the current client uid.

    Returns True (and refreshes ``self._ctx.client_uid_time``) on a match;
    otherwise sends a 400 'invalid uid' response and returns False.
    """
    if uid != self._ctx.client_uid:
        self.send_error(400, 'invalid uid')
        return False
    self._ctx.client_uid_time = time.time()
    return True


def _parse_hosts(self, value):
    """Decode a JSON list of hosts into a list of str; None if not a JSON list."""
    try:
        hosts = json.loads(value)
    except (TypeError, ValueError):
        return None
    if not isinstance(hosts, list):
        return None
    return [str(host) for host in hosts]