Ejemplo n.º 1
0
 def serve_clients(self, host, serve, task=None):
     """Coroutine: send a 'SERVE_CLIENTS:' request to the node at `host`.

     The node is looked up in `self.nodes` by the address that
     `dispy._node_ipaddr(host)` yields. On success the node's `serve`
     attribute is set from the 'serve' entry of the reply.

     The result is delivered by raising `StopIteration(value)`: 0 on
     success, -1 if the node is unknown, has no auth, or anything in the
     exchange fails.

     `task` is not used in the body (presumably supplied by the coroutine
     framework; verify against callers).
     """
     target = self.nodes.get(dispy._node_ipaddr(host), None)
     if not (target and target._priv.auth):
         raise StopIteration(-1)

     conn = AsyncSocket(
         socket.socket(target._priv.sock_family, socket.SOCK_STREAM),
         keyfile=self.keyfile, certfile=self.certfile)
     conn.settimeout(MsgTimeout)

     # Pessimistic default; only flipped once the whole exchange worked.
     status = -1
     try:
         yield conn.connect((target.ip_addr, target._priv.port))
         yield conn.sendall(target._priv.auth)
         request = 'SERVE_CLIENTS:'.encode() + serialize({'serve': serve})
         yield conn.send_msg(request)
         reply = yield conn.recv_msg()
         target.serve = deserialize(reply)['serve']
         status = 0
     except Exception:
         dispy.logger.debug('Setting serve clients %s to %s failed', host,
                            serve)
     finally:
         conn.close()
     raise StopIteration(status)
Ejemplo n.º 2
0
    def __init__(self,
                 ip_addrs=[],
                 node_port=51348,
                 listen_port=0,
                 scheduler_node=None,
                 scheduler_port=51347):
        """Record the port settings and start the listener tasks.

        Each entry of `ip_addrs` is passed to `dispy.node_addrinfo`; entries
        for which it returns a false value are logged and skipped. An empty
        `ip_addrs` is treated as `[None]`. When `listen_port` is false,
        `node_port` is used in its place.

        Three tasks (UDP listener, TCP listener, scheduler UDP) are created
        per accepted address. The `*_task` attributes are overwritten on
        every iteration, so they refer to the tasks of the last address.
        """
        addrinfos = []
        for ip_addr in (ip_addrs or [None]):
            addrinfo = dispy.node_addrinfo(ip_addr)
            if addrinfo:
                addrinfos.append(addrinfo)
            else:
                logger.warning('Ignoring invalid ip_addr %s', ip_addr)

        self.node_port = node_port
        self.listen_port = listen_port or node_port
        self.scheduler_port = scheduler_port
        # Keep the scheduler address only if it is truthy.
        scheduler_ip = dispy._node_ipaddr(scheduler_node)
        self.scheduler_ip_addrs = [scheduler_ip] if scheduler_ip else []

        for addrinfo in addrinfos:
            self.listen_udp_task = Task(self.listen_udp_proc, addrinfo)
            self.listen_tcp_task = Task(self.listen_tcp_proc, addrinfo)
            self.sched_udp_task = Task(self.sched_udp_proc, addrinfo)

        logger.info('version %s started', dispy._dispy_version)
Ejemplo n.º 3
0
 def service_time(self, host, control, time, task=None):
     """Coroutine: send a 'SERVICE_TIME:' request to the node at `host`.

     The node is looked up in `self.nodes` by the address that
     `dispy._node_ipaddr(host)` yields. `control` and `time` are forwarded
     unchanged in the request. On success the node's `service_start`,
     `service_stop` and `service_end` attributes are set from the reply.

     The result is delivered by raising `StopIteration(value)`: 0 on
     success, -1 if the node is unknown, has no auth, or anything in the
     exchange fails.

     `task` is not used in the body (presumably supplied by the coroutine
     framework; verify against callers).
     """
     node = self.nodes.get(dispy._node_ipaddr(host), None)
     if not node or not node._priv.auth:
         raise StopIteration(-1)
     sock = AsyncSocket(socket.socket(node._priv.sock_family,
                                      socket.SOCK_STREAM),
                        keyfile=self.keyfile,
                        certfile=self.certfile)
     sock.settimeout(MsgTimeout)
     try:
         yield sock.connect((node.ip_addr, node._priv.port))
         yield sock.sendall(node._priv.auth)
         yield sock.send_msg('SERVICE_TIME:'.encode() +
                             serialize({
                                 'control': control,
                                 'time': time
                             }))
         resp = yield sock.recv_msg()
         info = deserialize(resp)
         # Read every field before touching the node, so a reply that lacks
         # a key cannot leave the node partially updated.
         service_start = info['service_start']
         service_stop = info['service_stop']
         service_end = info['service_end']
         node.service_start = service_start
         node.service_stop = service_stop
         node.service_end = service_end
         resp = 0
     except Exception:
         resp = -1
     finally:
         # 'except Exception' does not catch GeneratorExit, which is raised
         # at a yield when the generator is closed; closing here keeps the
         # socket from leaking in that case.
         sock.close()
     if resp:
         dispy.logger.debug('Setting service %s time of %s to %s failed',
                            control, host, time)
     raise StopIteration(resp)
Ejemplo n.º 4
0
 def __init__(self, node_port=51348, listen_port=0, scheduler_node=None, scheduler_port=51347):
     """Record the port settings and create the three listener coroutines.

     When `listen_port` is false, `node_port` is used in its place.
     `scheduler_ip_addrs` holds the result of
     `dispy._node_ipaddr(scheduler_node)` if it is truthy, else it is empty.
     """
     self.node_port = node_port
     self.listen_port = listen_port if listen_port else node_port
     self.scheduler_port = scheduler_port

     resolved = dispy._node_ipaddr(scheduler_node)
     self.scheduler_ip_addrs = [resolved] if resolved else []

     self.listen_udp_coro = asyncoro.Coro(self.listen_udp_proc)
     self.listen_tcp_coro = asyncoro.Coro(self.listen_tcp_proc)
     self.sched_udp_coro = asyncoro.Coro(self.sched_udp_proc)
Ejemplo n.º 5
0
    def __init__(self, ip_addrs=[], node_port=51348, relay_port=0,
                 scheduler_nodes=[], scheduler_port=51347, ipv4_udp_multicast=False,
                 secret='', certfile=None, keyfile=None):
        """Store the relay configuration and start the relay tasks.

        Each entry of `ip_addrs` is passed to `dispy.host_addrinfo`; entries
        for which it returns a false value are logged and skipped. An empty
        `ip_addrs` is treated as `[None]`. When `relay_port` is false,
        `node_port` is used in its place. `certfile` and `keyfile` are
        stored as absolute paths, or None when not given.

        A TCP relay task is started per accepted address. UDP relay and
        scheduler tasks are started once per distinct `bind_addr`. For each
        such bind address, one `verify_broadcast` task is started per
        entry of `scheduler_nodes` that `dispy._node_ipaddr` maps to a
        truthy address.
        """
        self.ipv4_udp_multicast = bool(ipv4_udp_multicast)

        addrinfos = []
        for ip_addr in (ip_addrs or [None]):
            addrinfo = dispy.host_addrinfo(host=ip_addr, ipv4_multicast=self.ipv4_udp_multicast)
            if addrinfo:
                addrinfos.append(addrinfo)
            else:
                logger.warning('Ignoring invalid ip_addr %s', ip_addr)

        self.node_port = node_port
        self.relay_port = relay_port or node_port
        self.ip_addrs = set()
        self.scheduler_ip_addr = None
        self.scheduler_port = scheduler_port
        self.secret = secret
        self.certfile = os.path.abspath(certfile) if certfile else None
        self.keyfile = os.path.abspath(keyfile) if keyfile else None

        # Keyed by bind address, so addresses sharing one bind address get a
        # single set of UDP tasks (the last addrinfo seen wins).
        udp_addrinfos = {}
        for addrinfo in addrinfos:
            self.ip_addrs.add(addrinfo.ip)
            Task(self.relay_tcp_proc, addrinfo)
            udp_addrinfos[addrinfo.bind_addr] = addrinfo

        scheduler_ip_addrs = [
            resolved for resolved in map(dispy._node_ipaddr, scheduler_nodes)
            if resolved
        ]

        for bind_addr, addrinfo in udp_addrinfos.items():
            Task(self.relay_udp_proc, bind_addr, addrinfo)
            Task(self.sched_udp_proc, bind_addr, addrinfo)
            for sched_addr in scheduler_ip_addrs:
                msg = {'version': __version__, 'ip_addrs': [sched_addr],
                       'port': self.scheduler_port, 'sign': None}
                Task(self.verify_broadcast, addrinfo, msg)

        logger.info('version %s started', dispy._dispy_version)
Ejemplo n.º 6
0
    def __init__(self, ip_addrs=[], relay_port=0, scheduler_nodes=[], scheduler_port=0,
                 ipv4_udp_multicast=False, secret='', certfile=None, keyfile=None):
        """Store the relay configuration and start the relay tasks.

        Each entry of `ip_addrs` is passed to `dispy.host_addrinfo`; entries
        for which it returns a false value are logged and skipped. An empty
        `ip_addrs` is treated as `[None]`. `node_port` is obtained by
        evaluating `dispy.config.NodePort`. `relay_port` and
        `scheduler_port` are stored as given. `certfile` and `keyfile` are
        stored as absolute paths, or None when not given.

        A TCP relay task is started per accepted address. UDP relay and
        scheduler tasks are started once per distinct `bind_addr`. For each
        such bind address, one `verify_broadcast` task is started per
        entry of `scheduler_nodes` that `dispy._node_ipaddr` maps to a
        truthy address.
        """
        self.ipv4_udp_multicast = bool(ipv4_udp_multicast)

        addrinfos = []
        for ip_addr in (ip_addrs or [None]):
            addrinfo = dispy.host_addrinfo(host=ip_addr, ipv4_multicast=self.ipv4_udp_multicast)
            if addrinfo:
                addrinfos.append(addrinfo)
            else:
                logger.warning('Ignoring invalid ip_addr %s', ip_addr)

        # NOTE(review): eval() runs whatever expression the config value
        # holds; confirm dispy.config.NodePort can never carry untrusted text.
        self.node_port = eval(dispy.config.NodePort)
        self.scheduler_port = scheduler_port
        self.relay_port = relay_port
        self.ip_addrs = set()
        self.scheduler_ip_addr = None
        self.secret = secret
        self.certfile = os.path.abspath(certfile) if certfile else None
        self.keyfile = os.path.abspath(keyfile) if keyfile else None

        # Keyed by bind address, so addresses sharing one bind address get a
        # single set of UDP tasks (the last addrinfo seen wins).
        udp_addrinfos = {}
        for addrinfo in addrinfos:
            self.ip_addrs.add(addrinfo.ip)
            Task(self.relay_tcp_proc, addrinfo)
            udp_addrinfos[addrinfo.bind_addr] = addrinfo

        scheduler_ip_addrs = [
            resolved for resolved in map(dispy._node_ipaddr, scheduler_nodes)
            if resolved
        ]

        for bind_addr, addrinfo in udp_addrinfos.items():
            Task(self.relay_udp_proc, bind_addr, addrinfo)
            Task(self.sched_udp_proc, bind_addr, addrinfo)
            for sched_addr in scheduler_ip_addrs:
                msg = {'version': __version__, 'ip_addrs': [sched_addr],
                       'port': self.scheduler_port, 'sign': None}
                Task(self.verify_broadcast, addrinfo, msg)

        logger.info('version %s started', dispy._dispy_version)
Ejemplo n.º 7
0
 def __init__(self,
              node_port=51348,
              listen_port=0,
              scheduler_node=None,
              scheduler_port=51347):
     """Record the port settings and create the three listener coroutines.

     A false `listen_port` falls back to `node_port`. `scheduler_ip_addrs`
     is a list holding `dispy._node_ipaddr(scheduler_node)` when that value
     is truthy, and an empty list otherwise.
     """
     if not listen_port:
         listen_port = node_port
     self.node_port, self.listen_port = node_port, listen_port
     self.scheduler_port = scheduler_port
     self.scheduler_ip_addrs = [
         ip for ip in (dispy._node_ipaddr(scheduler_node),) if ip
     ]
     # Created in the same order as before: UDP listener, TCP listener,
     # scheduler UDP.
     self.listen_udp_coro = asyncoro.Coro(self.listen_udp_proc)
     self.listen_tcp_coro = asyncoro.Coro(self.listen_tcp_proc)
     self.sched_udp_coro = asyncoro.Coro(self.sched_udp_proc)
Ejemplo n.º 8
0
 def serve_clients(self, host, serve, task=None):
     """Coroutine: send a 'SERVE_CLIENTS:' request to the node at `host`.

     The node is looked up in `self.nodes` by the address that
     `dispy._node_ipaddr(host)` yields. On success the node's `serve`
     attribute is set from the 'serve' entry of the reply.

     The result is delivered by raising `StopIteration(value)`: 0 on
     success, -1 if the node is unknown, has no auth, or anything in the
     exchange fails.

     `task` is not used in the body (presumably supplied by the coroutine
     framework; verify against callers).
     """
     peer = self.nodes.get(dispy._node_ipaddr(host), None)
     if not (peer and peer._priv.auth):
         raise StopIteration(-1)

     raw_sock = socket.socket(peer._priv.sock_family, socket.SOCK_STREAM)
     channel = AsyncSocket(raw_sock, keyfile=self.keyfile, certfile=self.certfile)
     channel.settimeout(MsgTimeout)

     # Stays -1 unless the whole exchange completes.
     outcome = -1
     try:
         yield channel.connect((peer.ip_addr, peer._priv.port))
         yield channel.sendall(peer._priv.auth)
         payload = 'SERVE_CLIENTS:' + serialize({'serve': serve})
         yield channel.send_msg(payload)
         answer = yield channel.recv_msg()
         peer.serve = deserialize(answer)['serve']
         outcome = 0
     except Exception:
         dispy.logger.debug('Setting serve %s to %s failed', host, serve)
     finally:
         channel.close()
     raise StopIteration(outcome)
Ejemplo n.º 9
0
 def service_time(self, host, control, time, task=None):
     """Coroutine: send a 'SERVICE_TIME:' request to the node at `host`.

     The node is looked up in `self.nodes` by the address that
     `dispy._node_ipaddr(host)` yields. `control` and `time` are forwarded
     unchanged in the request. On success the node's `service_start`,
     `service_stop` and `service_end` attributes are set from the reply.

     The result is delivered by raising `StopIteration(value)`: 0 on
     success, -1 if the node is unknown, has no auth, or anything in the
     exchange fails.

     `task` is not used in the body (presumably supplied by the coroutine
     framework; verify against callers).
     """
     node = self.nodes.get(dispy._node_ipaddr(host), None)
     if not node or not node._priv.auth:
         raise StopIteration(-1)
     sock = AsyncSocket(socket.socket(node._priv.sock_family, socket.SOCK_STREAM),
                        keyfile=self.keyfile, certfile=self.certfile)
     sock.settimeout(MsgTimeout)
     try:
         yield sock.connect((node.ip_addr, node._priv.port))
         yield sock.sendall(node._priv.auth)
         yield sock.send_msg('SERVICE_TIME:' + serialize({'control': control, 'time': time}))
         resp = yield sock.recv_msg()
         info = deserialize(resp)
         # Read every field before touching the node, so a reply that lacks
         # a key cannot leave the node partially updated.
         service_start = info['service_start']
         service_stop = info['service_stop']
         service_end = info['service_end']
         node.service_start = service_start
         node.service_stop = service_stop
         node.service_end = service_end
         resp = 0
     except Exception:
         resp = -1
     finally:
         # 'except Exception' does not catch GeneratorExit, which is raised
         # at a yield when the generator is closed; closing here keeps the
         # socket from leaking in that case.
         sock.close()
     if resp:
         dispy.logger.debug('Setting service %s time of %s to %s failed', control, host, time)
     raise StopIteration(resp)
Ejemplo n.º 10
0
 def __init__(self,
              ip_addr=None,
              node_port=51348,
              listen_port=0,
              scheduler_node=None,
              scheduler_port=51347):
     """Record the settings, create the listener coroutines and pick the
     broadcast / multicast address.

     `self.addrinfo` is whatever `dispy.node_addrinfo(ip_addr)` returns. A
     false `listen_port` falls back to `node_port`.

     For an AF_INET address, `_broadcast` defaults to '<broadcast>'. If
     `netifaces` is available, it is replaced by the 'broadcast' value of
     the first interface link whose 'addr' equals `self.addrinfo[4][0]`.
     For any other family, `_broadcast` is 'ff02::1' and `self.mreq` is
     built from that address followed by the last element of
     `self.addrinfo[4]`, packed as a native unsigned int.
     """
     self.addrinfo = dispy.node_addrinfo(ip_addr)
     self.node_port = node_port
     self.listen_port = listen_port or node_port
     self.scheduler_port = scheduler_port
     scheduler_ip = dispy._node_ipaddr(scheduler_node)
     self.scheduler_ip_addrs = [scheduler_ip] if scheduler_ip else []
     self.listen_udp_coro = asyncoro.Coro(self.listen_udp_proc)
     self.listen_tcp_coro = asyncoro.Coro(self.listen_tcp_proc)
     self.sched_udp_coro = asyncoro.Coro(self.sched_udp_proc)

     if self.addrinfo[0] != socket.AF_INET:
         # Not IPv4: use the ff02::1 multicast group.
         self._broadcast = 'ff02::1'
         group_info = socket.getaddrinfo(self._broadcast, None)[0]
         self.mreq = socket.inet_pton(group_info[0], group_info[4][0])
         self.mreq += struct.pack('@I', self.addrinfo[4][-1])
     else:
         self._broadcast = '<broadcast>'
         if netifaces:
             # Lazily walk every IPv4 link of every interface and stop at
             # the first one that carries our own address.
             own_links = (
                 link
                 for iface in netifaces.interfaces()
                 for link in netifaces.ifaddresses(iface).get(
                     netifaces.AF_INET, [])
                 if link['addr'] == self.addrinfo[4][0]
             )
             own_link = next(own_links, None)
             if own_link is not None:
                 self._broadcast = own_link.get('broadcast', '<broadcast>')
     logger.info('version %s started', dispy._dispy_version)
Ejemplo n.º 11
0
        def do_POST(self):
            """Handle a POST request from the admin web client.

            The request name is the URL path without its leading '/'; its
            parameters are read from the submitted form. Every request
            except 'get_uid' must carry the 'uid' currently held by the
            context (`self._ctx.client_uid`), otherwise it is answered with
            HTTP 400. Successful requests are answered with HTTP 200 and a
            JSON body; unknown request names are logged and answered with
            HTTP 400.
            """

            def send_json(payload):
                # Common success reply: 200 with a UTF-8 encoded JSON body.
                self.send_response(200)
                self.send_header('Content-Type',
                                 'application/json; charset=utf-8')
                self.end_headers()
                self.wfile.write(json.dumps(payload).encode())

            try:
                form = cgi.FieldStorage(fp=self.rfile,
                                        headers=self.headers,
                                        environ={'REQUEST_METHOD': 'POST'})
                client_request = self.path[1:]
            except Exception:
                logger.debug('Ignoring invalid POST request from %s',
                             self.client_address[0])
                self.send_error(400)
                return

            # FieldStorage leaves 'list' as None when the body is not form
            # data; treat that as "no fields" instead of failing on iteration.
            fields = form.list or []

            if client_request == 'update':
                uid = None
                for item in fields:
                    if item.name == 'uid':
                        uid = item.value.strip()
                        break
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                self._ctx.client_uid_time = time.time()
                self._ctx.lock.acquire()
                try:
                    nodes = self.__class__.json_encode_nodes(self._ctx.updates)
                    self._ctx.updates.clear()
                finally:
                    # Release even if encoding fails, so later requests are
                    # not blocked on a lock that is never released.
                    self._ctx.lock.release()
                send_json(nodes)
                return

            elif client_request == 'node_info':
                ip_addr = None
                uid = None
                for item in fields:
                    if item.name == 'host':
                        # if it looks like IP address, skip resolving
                        if re.match(DispyAdminServer._NodeInfo.ip_re,
                                    item.value):
                            ip_addr = item.value
                        else:
                            ip_addr = dispy._node_ipaddr(item.value)
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                self._ctx.client_uid_time = time.time()
                node = self._ctx.nodes.get(ip_addr, None)
                if node:
                    # Work on a copy of the attributes, dropping '_priv' and
                    # flattening avail_info to a dict so it can be encoded.
                    node = dict(node.__dict__)
                    node.pop('_priv', None)
                    if node['avail_info']:
                        node['avail_info'] = node['avail_info'].__dict__
                else:
                    node = {}
                send_json(node)
                return

            elif client_request == 'status':
                uid = None
                for item in fields:
                    if item.name == 'uid':
                        uid = item.value.strip()
                        break
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                self._ctx.client_uid_time = time.time()
                self._ctx.lock.acquire()
                try:
                    nodes = self.__class__.json_encode_nodes(self._ctx.nodes)
                finally:
                    self._ctx.lock.release()
                send_json(nodes)
                return

            elif client_request == 'get_uid':
                uid = None
                # Start as None so a request without 'poll_interval' is
                # rejected below instead of hitting an unbound local.
                poll_interval = None
                for item in fields:
                    if item.name == 'uid':
                        uid = item.value.strip()
                    elif item.name == 'poll_interval':
                        try:
                            poll_interval = int(item.value)
                        except Exception:
                            self.send_error(400, 'invalid poll interval')
                            return
                        # Explicit check; an assert would be stripped under -O.
                        if poll_interval < 5:
                            self.send_error(400, 'invalid poll interval')
                            return
                if poll_interval is None:
                    self.send_error(400, 'invalid poll interval')
                    return
                # TODO: only allow from http server?
                uid = self._ctx.set_uid(self.client_address[0], poll_interval,
                                        uid)
                if not uid:
                    self.send_error(400, 'invalid uid')
                    return
                send_json(uid)
                return

            elif client_request == 'set_secret':
                secret = None
                uid = None
                for item in fields:
                    if item.name == 'secret':
                        secret = item.value.strip()
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if secret and uid == self._ctx.client_uid:
                    self._ctx.client_uid_time = time.time()
                    self._ctx.set_secret(secret)
                    send_json(0)
                else:
                    self.send_error(400)
                return

            elif client_request == 'add_node':
                host = ''
                port = None
                uid = None
                for item in fields:
                    if item.name == 'host':
                        host = item.value
                    elif item.name == 'port':
                        try:
                            port = int(item.value)
                        except Exception:
                            port = None
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                ip_addr = dispy._node_ipaddr(host) if host and port else None
                if not ip_addr:
                    # Always answer: a missing or invalid host/port used to
                    # end the request without any response.
                    self.send_error(400)
                    return
                info = {'ip_addr': ip_addr, 'port': port}
                Task(self._ctx.add_node, info)
                send_json(0)
                return

            elif client_request == 'service_time':
                hosts = []
                svc_time = None
                control = None
                uid = None
                for item in fields:
                    if item.name == 'hosts':
                        try:
                            hosts = [str(host)
                                     for host in json.loads(item.value)]
                        except Exception:
                            # Malformed JSON from the client is a bad request.
                            self.send_error(400, 'invalid nodes')
                            return
                    elif item.name == 'control':
                        control = item.value
                    elif item.name == 'time':
                        svc_time = item.value.strip()
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                self._ctx.client_uid_time = time.time()
                for host in hosts:
                    Task(self._ctx.service_time, host, control, svc_time)
                send_json(0)
                return

            elif client_request == 'set_cpus':
                hosts = []
                cpus = None
                uid = None
                for item in fields:
                    if item.name == 'hosts':
                        try:
                            hosts = [str(host)
                                     for host in json.loads(item.value)]
                        except Exception:
                            # Malformed JSON from the client is a bad request.
                            self.send_error(400, 'invalid nodes')
                            return
                        if not hosts:
                            self.send_error(400, 'invalid nodes')
                            return
                    elif item.name == 'cpus':
                        cpus = item.value
                        if cpus is not None:
                            try:
                                cpus = int(item.value)
                            except Exception:
                                self.send_error(400, 'invalid CPUs')
                                return
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                for host in hosts:
                    Task(self._ctx.set_cpus, host, cpus)
                self._ctx.client_uid_time = time.time()
                send_json(0)
                return

            elif client_request == 'serve_clients':
                host = ''
                serve = None
                uid = None
                for item in fields:
                    if item.name == 'host':
                        host = item.value
                    elif item.name == 'serve':
                        serve = item.value
                        try:
                            serve = int(serve)
                        except Exception:
                            # Left as a string; rejected by the isinstance
                            # check below.
                            pass
                    elif item.name == 'uid':
                        uid = item.value.strip()
                # The task is only started (and waited for) once the uid and
                # the 'serve' value have been validated.
                if (uid == self._ctx.client_uid and isinstance(serve, int)
                        and Task(self._ctx.serve_clients, host,
                                 serve).value() == 0):
                    self._ctx.client_uid_time = time.time()
                    send_json(0)
                else:
                    self.send_error(400)
                return

            elif client_request == 'poll_interval':
                uid = None
                interval = None
                for item in fields:
                    if item.name == 'interval':
                        try:
                            interval = int(item.value)
                        except Exception:
                            # NOTE(review): this only warns/rejects when an
                            # earlier 'interval' field already parsed; a lone
                            # invalid value falls through with interval=None
                            # and is left to set_poll_interval - confirm that
                            # is intended.
                            if interval is not None:
                                logger.warning(
                                    '%s: invalid poll interval "%s" ignored',
                                    self._ctx.client_uid, item.value)
                                self.send_error(400)
                                return
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if (uid == self._ctx.client_uid
                        and self._ctx.set_poll_interval(interval) == 0):
                    self._ctx.client_uid_time = time.time()
                    send_json(0)
                else:
                    self.send_error(400)
                return

            logger.debug('Bad POST request from %s: %s',
                         self.client_address[0], client_request)
            self.send_error(400)
            return
Ejemplo n.º 12
0
        def do_POST(self):
            try:
                form = cgi.FieldStorage(fp=self.rfile,
                                        headers=self.headers,
                                        environ={'REQUEST_METHOD': 'POST'})
                client_request = self.path[1:]
            except Exception:
                logger.debug('Ignoring invalid POST request from %s',
                             self.client_address[0])
                self.send_error(400)
                return

            if client_request == 'node_info':
                ip_addr = None
                for item in form.list:
                    if item.name == 'host':
                        # if it looks like IP address, skip resolving
                        if re.match(DispyHTTPServer._ClusterInfo.ip_re,
                                    item.value):
                            ip_addr = item.value
                        else:
                            ip_addr = dispy._node_ipaddr(item.value)
                        break
                    self._ctx._cluster_lock.acquire()
                    cluster_infos = [
                        (name, cluster_info)
                        for name, cluster_info in self._ctx._clusters.items()
                    ]
                    self._ctx._cluster_lock.release()
                    cluster_jobs = {}
                    node = None
                    show_args = self._ctx._show_args
                for name, cluster_info in cluster_infos:
                    cluster_node = cluster_info.status.get(ip_addr, None)
                    if not cluster_node:
                        cluster_jobs[name] = []
                        continue
                    if node:
                        node.jobs_done += cluster_node.jobs_done
                        node.cpu_time += cluster_node.cpu_time
                        node.update_time = max(node.update_time,
                                               cluster_node.update_time)
                        node.tx += cluster_node.tx
                        node.rx += cluster_node.rx
                    else:
                        node = copy.copy(cluster_node)
                        # jobs = cluster_info.cluster.node_jobs(ip_addr)
                    jobs = [
                        job for job in dict_iter(cluster_info.jobs, 'values')
                        if job.ip_addr == ip_addr
                    ]
                    # args and kwargs are sent as strings in Python,
                    # so an object's __str__ or __repr__ is used if provided;
                    # TODO: check job is in _ctx's jobs?
                    jobs = [{
                        'uid':
                        job._uid,
                        'job_id':
                        str(job.id),
                        'args':
                        ', '.join(str(arg)
                                  for arg in job._args) if show_args else '',
                        'kwargs':
                        ', '.join('%s=%s' % (key, val)
                                  for key, val in job._kwargs.items())
                        if show_args else '',
                        'start_time_ms':
                        int(1000 * job.start_time),
                        'cluster':
                        name
                    } for job in jobs]
                    cluster_jobs[name] = jobs
                    self.send_response(200)
                    self.send_header('Content-Type',
                                     'application/json; charset=utf-8')
                    self.end_headers()
                if node:
                    if node.avail_info:
                        node.avail_info = node.avail_info.__dict__
                        self.wfile.write(
                            json.dumps({
                                'node': node.__dict__,
                                'cluster_jobs': cluster_jobs
                            }).encode())
                return

            elif client_request == 'cancel_jobs':
                uids = []
                for item in form.list:
                    if item.name == 'uid':
                        try:
                            uids.append(int(item.value))
                        except ValueError:
                            logger.debug('Cancel job uid "%s" is invalid',
                                         item.value)

                self._ctx._cluster_lock.acquire()
                cluster_jobs = [
                    (cluster_info.cluster, cluster_info.jobs.get(uid, None))
                    for cluster_info in self._ctx._clusters.values()
                    for uid in uids
                ]
                self._ctx._cluster_lock.release()
                cancelled = []
                for cluster, job in cluster_jobs:
                    if not job:
                        continue
                    if cluster.cancel(job) == 0:
                        cancelled.append(job._uid)
                        self.send_response(200)
                        self.send_header('Content-Type',
                                         'application/json; charset=utf-8')
                        self.end_headers()
                        self.wfile.write(json.dumps(cancelled).encode())
                return

            elif client_request == 'add_node':
                node = {'host': '', 'port': None, 'cpus': 0, 'cluster': None}
                node_id = None
                cluster = None
                for item in form.list:
                    if item.name == 'host':
                        node['host'] = item.value
                    elif item.name == 'cluster':
                        node['cluster'] = item.value
                    elif item.name == 'port':
                        node['port'] = item.value
                    elif item.name == 'cpus':
                        try:
                            node['cpus'] = int(item.value)
                        except Exception:
                            pass
                    elif item.name == 'id':
                        node_id = item.value
                if node['host']:
                    self._ctx._cluster_lock.acquire()
                    clusters = [
                        cluster_info.cluster
                        for name, cluster_info in self._ctx._clusters.items()
                        if name == node['cluster'] or not node['cluster']
                    ]
                    self._ctx._cluster_lock.release()
                    for cluster in clusters:
                        cluster.allocate_node(node)
                        self.send_response(200)
                        self.send_header('Content-Type', 'text/html')
                        self.end_headers()
                        node['id'] = node_id
                        self.wfile.write(json.dumps(node).encode())
                    return

            elif (client_request == 'close_node'
                  or client_request == 'allocate_node'
                  or client_request == 'deallocate_node'):
                nodes = []
                cluster_infos = []
                resp = -1
                for item in form.list:
                    if item.name == 'cluster':
                        self._ctx._cluster_lock.acquire()
                        if item.value == '*':
                            cluster_infos = list(self._ctx._clusters.values())
                        else:
                            cluster_infos = [
                                self._ctx._clusters.get(item.value, None)
                            ]
                            if not cluster_infos[0]:
                                cluster_infos = []
                                self._ctx._cluster_lock.release()
                    elif item.name == 'nodes':
                        nodes = json.loads(item.value)
                        nodes = [str(node) for node in nodes]

                if cluster_infos and nodes:
                    resp = 0
                    for cluster_info in cluster_infos:
                        fn = getattr(cluster_info.cluster, client_request)
                        if fn:
                            for node in nodes:
                                resp |= fn(node)
                        else:
                            resp = -1
                            self.send_response(200)
                            self.send_header(
                                'Content-Type',
                                'application/json; charset=utf-8')
                            self.end_headers()
                            self.wfile.write(json.dumps(resp).encode())
                return

            elif client_request == 'update':
                for item in form.list:
                    if item.name == 'timeout':
                        try:
                            timeout = int(item.value)
                            if timeout < 1:
                                timeout = 0
                                self._ctx._poll_sec = timeout
                        except Exception:
                            logger.warning(
                                'HTTP client %s: invalid timeout "%s" ignored',
                                self.client_address[0], item.value)
                    elif item.name == 'show_job_args':
                        if item.value == 'true':
                            self._ctx._show_args = True
                        else:
                            self._ctx._show_args = False
                return

            elif client_request == 'set_cpus':
                node_cpus = {}
                for item in form.list:
                    self._ctx._cluster_lock.acquire()
                    for cluster_info in self._ctx._clusters.values():
                        node = cluster_info.status.get(item.name, None)
                        if node:
                            node_cpus[
                                item.
                                name] = cluster_info.cluster.set_node_cpus(
                                    item.name, item.value)
                            if node_cpus[item.name] >= 0:
                                break
                            self._ctx._cluster_lock.release()

                self.send_response(200)
                self.send_header('Content-Type',
                                 'application/json; charset=utf-8')
                self.end_headers()
                self.wfile.write(json.dumps(node_cpus).encode())
                return

            logger.debug('Bad POST request from %s: %s',
                         self.client_address[0], client_request)
            self.send_error(400)
            return
Ejemplo n.º 13
0
        def do_POST(self):
            """Handle a POST request from the admin web client.

            The request path, minus its leading '/', selects the action:
            'update', 'node_info', 'status', 'get_uid', 'set_secret',
            'add_node', 'service_time', 'set_cpus', 'serve_clients' or
            'poll_interval'. Every action except 'get_uid' requires the 'uid'
            form field to equal ``self._ctx.client_uid``. A successful action
            replies with status 200 and a JSON body; every other outcome
            (unknown action, bad uid, missing/invalid fields) replies with 400.
            """
            try:
                form = cgi.FieldStorage(fp=self.rfile, headers=self.headers,
                                        environ={'REQUEST_METHOD': 'POST'})
                client_request = self.path[1:]
            except Exception:
                logger.debug('Ignoring invalid POST request from %s', self.client_address[0])
                self.send_error(400)
                return

            # FieldStorage.list is None when the body is not form data; treat
            # that as "no fields" instead of failing on iteration.
            fields = form.list or []

            def send_json(obj):
                """Reply with status 200 and `obj` serialized as JSON."""
                self.send_response(200)
                self.send_header('Content-Type', 'application/json; charset=utf-8')
                self.end_headers()
                self.wfile.write(json.dumps(obj).encode())

            if client_request == 'update':
                uid = None
                for item in fields:
                    if item.name == 'uid':
                        uid = item.value.strip()
                        break
                # NOTE(review): if client_uid can still be None here, a request
                # without a 'uid' field passes this check - confirm that
                # client_uid is always set (via 'get_uid') before other actions.
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                self._ctx.client_uid_time = time.time()
                self._ctx.lock.acquire()
                try:
                    # encode and clear under the lock so no update is lost
                    nodes = self.__class__.json_encode_nodes(self._ctx.updates)
                    self._ctx.updates.clear()
                finally:
                    self._ctx.lock.release()
                send_json(nodes)
                return

            elif client_request == 'node_info':
                ip_addr = None
                uid = None
                for item in fields:
                    if item.name == 'host':
                        # if it looks like IP address, skip resolving
                        if re.match(DispyAdminServer._NodeInfo.ip_re, item.value):
                            ip_addr = item.value
                        else:
                            ip_addr = dispy._node_ipaddr(item.value)
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                self._ctx.client_uid_time = time.time()
                node = self._ctx.nodes.get(ip_addr, None)
                if node:
                    # reply with a copy of the node's attributes, without '_priv'
                    node = dict(node.__dict__)
                    node.pop('_priv', None)
                    if node['avail_info']:
                        node['avail_info'] = node['avail_info'].__dict__
                else:
                    node = {}
                send_json(node)
                return

            elif client_request == 'status':
                uid = None
                for item in fields:
                    if item.name == 'uid':
                        uid = item.value.strip()
                        break
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                self._ctx.client_uid_time = time.time()
                self._ctx.lock.acquire()
                try:
                    nodes = self.__class__.json_encode_nodes(self._ctx.nodes)
                finally:
                    self._ctx.lock.release()
                send_json(nodes)
                return

            elif client_request == 'get_uid':
                uid = None
                for item in fields:
                    if item.name == 'uid':
                        uid = item.value.strip()
                        break
                now = time.time()
                # TODO: only allow from http server?
                # Reject a different uid while the current one was used within
                # the last hour (3600 seconds).
                if (self._ctx.client_uid and uid != self._ctx.client_uid and
                    ((now - self._ctx.client_uid_time) < 3600)):
                    self.send_error(400, 'invalid uid')
                    return
                if not uid:
                    uid = hashlib.sha1(os.urandom(20)).hexdigest()
                self._ctx.client_uid = uid
                self._ctx.client_uid_time = time.time()
                send_json(uid)
                return

            elif client_request == 'set_secret':
                secret = None
                uid = None
                for item in fields:
                    if item.name == 'secret':
                        secret = item.value.strip()
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if secret and uid == self._ctx.client_uid:
                    self._ctx.client_uid_time = time.time()
                    self._ctx.set_secret(secret)
                    send_json(0)
                else:
                    self.send_error(400)
                return

            elif client_request == 'add_node':
                host = ''
                port = None
                uid = None
                for item in fields:
                    if item.name == 'host':
                        host = item.value
                    elif item.name == 'port':
                        try:
                            port = int(item.value)
                        except Exception:
                            port = None
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                ip_addr = dispy._node_ipaddr(host) if (host and port) else None
                if ip_addr:
                    info = {'ip_addr': ip_addr, 'port': port}
                    Task(self._ctx.add_node, info)
                    send_json(0)
                else:
                    # Missing host/port previously returned without any reply,
                    # leaving the client waiting; always answer.
                    self.send_error(400)
                return

            elif client_request == 'service_time':
                hosts = []
                svc_time = None
                control = None
                uid = None
                for item in fields:
                    if item.name == 'hosts':
                        try:
                            hosts = [str(host) for host in json.loads(item.value)]
                        except (ValueError, TypeError):
                            # malformed or non-list JSON from the client
                            self.send_error(400)
                            return
                    elif item.name == 'control':
                        control = item.value
                    elif item.name == 'time':
                        svc_time = item.value
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if uid != self._ctx.client_uid:
                    self.send_error(400, 'invalid uid')
                    return
                self._ctx.client_uid_time = time.time()
                for host in hosts:
                    Task(self._ctx.service_time, host, control, svc_time)
                send_json(0)
                return

            elif client_request == 'set_cpus':
                host = None
                cpus = None
                uid = None
                for item in fields:
                    if item.name == 'host':
                        host = item.value.strip()
                    elif item.name == 'cpus':
                        try:
                            cpus = int(item.value.strip())
                        except Exception:
                            # leave cpus as None; rejected with 400 below
                            pass
                    elif item.name == 'uid':
                        uid = item.value.strip()
                # the task is only started once the cheap checks have passed
                if (host and cpus and uid == self._ctx.client_uid and
                    Task(self._ctx.set_cpus, host, cpus).value() == 0):
                    self._ctx.client_uid_time = time.time()
                    send_json(0)
                else:
                    self.send_error(400)
                return

            elif client_request == 'serve_clients':
                host = ''
                serve = None
                uid = None
                for item in fields:
                    if item.name == 'host':
                        host = item.value
                    elif item.name == 'serve':
                        serve = item.value
                        try:
                            serve = int(serve)
                        except Exception:
                            # serve stays a string; rejected by isinstance below
                            pass
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if (uid == self._ctx.client_uid and isinstance(serve, int) and
                    Task(self._ctx.serve_clients, host, serve).value() == 0):
                    self._ctx.client_uid_time = time.time()
                    send_json(0)
                else:
                    self.send_error(400)
                return

            elif client_request == 'poll_interval':
                uid = None
                interval = None
                for item in fields:
                    if item.name == 'interval':
                        try:
                            interval = int(item.value)
                        except Exception:
                            # NOTE(review): an invalid value is only rejected
                            # here when an earlier 'interval' field already
                            # parsed; otherwise None is handed to
                            # set_poll_interval below - confirm this is intended.
                            if interval is not None:
                                logger.warning('%s: invalid poll interval "%s" ignored',
                                               self._ctx.client_uid, item.value)
                                self.send_error(400)
                                return
                    elif item.name == 'uid':
                        uid = item.value.strip()
                if (uid == self._ctx.client_uid and self._ctx.set_poll_interval(interval) == 0):
                    self._ctx.client_uid_time = time.time()
                    send_json(0)
                else:
                    self.send_error(400)
                return

            logger.debug('Bad POST request from %s: %s', self.client_address[0], client_request)
            self.send_error(400)
            return
Ejemplo n.º 14
0
    def relay_pings(self, ip_addr='', netmask=None, node_port=51348,
                    scheduler_node=None, scheduler_port=51347):
        """Relay dispy PING messages over UDP; loops until an exception ends it.

        Two UDP sockets are bound, one on `node_port` and one on
        `scheduler_port`:

        * A 'PING:' message received on `node_port` from an address outside
          the local network is re-broadcast on `node_port`, and the scheduler
          address/port it carries are remembered.
        * A 'PING:' message received on `scheduler_port` is forwarded to the
          remembered scheduler with 'scheduler_ip_addr' added.

        If `scheduler_node` resolves to an address, an initial PING for that
        scheduler is broadcast first.

        ip_addr: local address, optionally as 'a.b.c.d/bits' when `netmask`
            is not given; defaults to the address of this host's name.
        netmask: dotted-quad string or integer mask used to recognize pings
            coming from the local network.

        Raises socket.error if `ip_addr` is not a valid IPv4 address or a
        socket cannot be bound.
        """
        ping = 'PING:'
        netaddr = None
        if not netmask:
            try:
                ip_addr, bits = ip_addr.split('/')
                socket.inet_aton(ip_addr)
                netmask = (0xffffffff << (32 - int(bits))) & 0xffffffff
                netaddr = (struct.unpack('>L', socket.inet_aton(ip_addr))[0]) & netmask
            except Exception:
                # no usable '/bits' suffix: fall back to a single-host mask
                netmask = '255.255.255.255'
        if ip_addr:
            socket.inet_aton(ip_addr)
        else:
            ip_addr = socket.gethostbyname(socket.gethostname())
        if not netaddr and netmask:
            try:
                if isinstance(netmask, str):
                    netmask = struct.unpack('>L', socket.inet_aton(netmask))[0]
                elif not isinstance(netmask, int):
                    raise ValueError('netmask must be a string or an integer')
                if netmask <= 0:
                    raise ValueError('netmask must be positive')
                netaddr = (struct.unpack('>L', socket.inet_aton(ip_addr))[0]) & netmask
            except Exception:
                logger.warning('Invalid netmask')

        # If either value cannot be packed as an IPv4 address, disable the
        # "own network" filter altogether.
        try:
            socket.inet_ntoa(struct.pack('>L', netaddr))
            socket.inet_ntoa(struct.pack('>L', netmask))
        except Exception:
            netaddr = netmask = None

        bc_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        node_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sched_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            bc_sock.bind(('', 0))
            bc_sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

            scheduler_ip_addr = _node_ipaddr(scheduler_node)
            if scheduler_ip_addr and scheduler_port:
                relay_request = serialize({'ip_addr':scheduler_ip_addr, 'port':scheduler_port,
                                           'version':_dispy_version, 'sign':None})
                bc_sock.sendto('PING:%s' % relay_request, ('<broadcast>', node_port))

            node_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            node_sock.bind(('', node_port))
            sched_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sched_sock.bind(('', scheduler_port))
            logger.info('Listening on %s:%s/%s', ip_addr, node_port, scheduler_port)
            while True:
                ready = select.select([node_sock, sched_sock], [], [])[0]
                for sock in ready:
                    if sock == node_sock:
                        msg, addr = node_sock.recvfrom(1024)
                        if not msg.startswith(ping):
                            logger.debug('Ignoring message "%s" from %s',
                                         msg[:min(len(msg), 5)], addr[0])
                            continue
                        if netaddr and (struct.unpack('>L', socket.inet_aton(addr[0]))[0] & netmask) == netaddr:
                            logger.debug('Ignoring own ping (from %s)', addr[0])
                            continue
                        logger.debug('Ping message from %s (%s)', addr[0], addr[1])
                        try:
                            info = unserialize(msg[len(ping):])
                            ping_ip_addr = info['ip_addr']
                            ping_port = info['port']
                            if info['version'] != _dispy_version:
                                raise ValueError('version mismatch')
                            # scheduler_sign = info['sign']
                            if not isinstance(ping_port, int):
                                raise ValueError('invalid scheduler port')
                        except Exception:
                            logger.debug('Ignoring ping message from %s (%s)', addr[0], addr[1])
                            logger.debug(traceback.format_exc())
                            continue
                        # Only a fully validated ping replaces the remembered
                        # scheduler, so a bad message cannot clobber it.
                        scheduler_ip_addr = ping_ip_addr
                        scheduler_port = ping_port
                        logger.debug('relaying ping from %s / %s', info['ip_addr'], addr[0])
                        if scheduler_ip_addr is None:
                            # ping carries no address: use the sender's
                            info['ip_addr'] = scheduler_ip_addr = addr[0]
                        relay_request = serialize(info)
                        bc_sock.sendto('PING:%s' % relay_request, ('<broadcast>', node_port))
                    else:
                        # select() was given only two sockets, so this is sched_sock
                        msg, addr = sched_sock.recvfrom(1024)
                        if msg.startswith(ping) and scheduler_ip_addr and scheduler_port:
                            try:
                                info = unserialize(msg[len(ping):])
                                if info['version'] != _dispy_version:
                                    raise ValueError('version mismatch')
                                if not isinstance(info['ip_addr'], str):
                                    raise ValueError('invalid ip_addr')
                                if not isinstance(info['port'], int):
                                    raise ValueError('invalid port')
                                # assert isinstance(info['cpus'], int)
                                info['scheduler_ip_addr'] = scheduler_ip_addr
                                relay_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                                try:
                                    relay_sock.sendto('PING:' + serialize(info),
                                                      (scheduler_ip_addr, scheduler_port))
                                finally:
                                    relay_sock.close()
                            except Exception:
                                logger.debug(traceback.format_exc())
                                # raise
                                logger.debug('Ignoring ping message from %s (%s)', addr[0], addr[1])
        finally:
            # release the ports when the loop is interrupted or setup fails
            for opened in (bc_sock, node_sock, sched_sock):
                opened.close()
Ejemplo n.º 15
0
    def __init__(self, cpus, ip_addr=None, ext_ip_addr=None, node_port=None,
                 scheduler_node=None, scheduler_port=None,
                 dest_path_prefix='', secret='', keyfile=None, certfile=None,
                 max_file_size=None, zombie_interval=60):
        """Create the node's sockets, bookkeeping state and service coroutines.

        cpus: number of processors to serve; must be between 1 and
            multiprocessing.cpu_count().
        ip_addr / ext_ip_addr: address to bind to and address used to look up
            the node's name; `ext_ip_addr` defaults to `ip_addr`, which
            defaults to the address of this host's name.
        node_port / scheduler_port: default to 51348 / 51347 when falsy.
        dest_path_prefix: working directory, created with owner-only
            permissions if missing; defaults to <sep>tmp<sep>dispy.
        secret: combined with a random signature to derive `auth_code`.
        keyfile / certfile: when `certfile` is given the TCP socket is
            wrapped with SSL.
        max_file_size: defaults to MaxFileSize when None.
        zombie_interval: stored multiplied by 60 (presumably given in
            minutes - confirm against callers).

        Raises Exception if `ip_addr` or `ext_ip_addr` cannot be resolved.
        """
        assert 0 < cpus <= multiprocessing.cpu_count()
        self.cpus = cpus

        # Resolve the address to bind to.
        if not ip_addr:
            self.name = socket.gethostname()
            ip_addr = socket.gethostbyname(self.name)
        else:
            ip_addr = _node_ipaddr(ip_addr)
            if not ip_addr:
                raise Exception('invalid ip_addr')

        # Resolve the externally visible address and derive the node name.
        if not ext_ip_addr:
            ext_ip_addr = ip_addr
        else:
            ext_ip_addr = _node_ipaddr(ext_ip_addr)
            if not ext_ip_addr:
                raise Exception('invalid ext_ip_addr')
        try:
            self.name = socket.gethostbyaddr(ext_ip_addr)[0]
        except:
            self.name = socket.gethostname()

        node_port = node_port or 51348
        scheduler_port = scheduler_port or 51347

        self.ip_addr = ip_addr
        self.ext_ip_addr = ext_ip_addr
        self.scheduler_port = scheduler_port
        self.pulse_interval = None
        self.keyfile = os.path.abspath(keyfile) if keyfile else keyfile
        self.certfile = os.path.abspath(certfile) if certfile else certfile

        self.asyncoro = AsynCoro()

        # TCP listener, SSL-wrapped when a certificate is configured.
        tcp_sock = self.tcp_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if self.certfile:
            tcp_sock = self.tcp_sock = ssl.wrap_socket(tcp_sock, keyfile=self.keyfile,
                                                       certfile=self.certfile)
        tcp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        tcp_sock.bind((self.ip_addr, node_port))
        self.address = tcp_sock.getsockname()
        tcp_sock.listen(30)

        # Working directory for transferred files.
        self.dest_path_prefix = (dest_path_prefix.strip().rstrip(os.sep) if dest_path_prefix
                                 else os.path.join(os.sep, 'tmp', 'dispy'))
        if not os.path.isdir(self.dest_path_prefix):
            os.makedirs(self.dest_path_prefix)
            # owner-only read/write/execute
            os.chmod(self.dest_path_prefix, stat.S_IRWXU)
        self.max_file_size = MaxFileSize if max_file_size is None else max_file_size

        # Runtime bookkeeping.
        self.avail_cpus = self.cpus
        self.scheduler_ip_addr = None
        self.terminate = False
        self.computations = {}
        self.file_uses = {}
        self.job_infos = {}
        self.lock = asyncoro.Lock()
        # NOTE: str.encode('hex') exists only in Python 2.
        self.signature = os.urandom(20).encode('hex')
        self.auth_code = hashlib.sha1(self.signature + secret).hexdigest()
        self.zombie_interval = 60 * zombie_interval

        logger.debug('auth_code for %s: %s', ip_addr, self.auth_code)

        # UDP socket on all interfaces, handed to the coroutine framework
        # once it is bound.
        udp_sock = self.udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        udp_sock.bind(('', node_port))
        logger.info('serving %s cpus at %s:%s', self.cpus, self.ip_addr, node_port)
        logger.debug('tcp server at %s:%s', self.address[0], self.address[1])
        self.udp_sock = AsynCoroSocket(udp_sock, blocking=False)

        scheduler_ip_addr = _node_ipaddr(scheduler_node)

        # Thread draining the replies queue.
        self.reply_Q = multiprocessing.Queue()
        self.reply_Q_thread = threading.Thread(target=self.__reply_Q)
        self.reply_Q_thread.start()

        self.timer_coro = Coro(self.timer_task)
        # self.tcp_coro = Coro(self.tcp_server)
        self.udp_coro = Coro(self.udp_server, scheduler_ip_addr)
Ejemplo n.º 16
0
    def relay_pings(self,
                    ip_addr='',
                    netmask=None,
                    node_port=51348,
                    scheduler_node=None,
                    scheduler_port=51347):
        """Relay dispy PING messages over UDP; loops until an exception ends it.

        Two UDP sockets are bound, one on `node_port` and one on
        `scheduler_port`:

        * A b'PING:' message received on `node_port` from an address outside
          the local network is re-broadcast on `node_port`; the scheduler
          addresses it carries (plus the sender's address) and its port are
          remembered.
        * A b'PING:' message received on `scheduler_port` is answered by
          sending the remembered scheduler details to the 'ip_addr'/'port'
          named in that message.

        If `scheduler_node` resolves to an address, an initial PING for that
        scheduler is broadcast first.

        ip_addr: local address, optionally as 'a.b.c.d/bits' when `netmask`
            is not given; defaults to the address of this host's name.
        netmask: dotted-quad string or integer mask used to recognize pings
            coming from the local network.

        Raises OSError (socket.error) if `ip_addr` is not a valid IPv4
        address or a socket cannot be bound.
        """
        ping = b'PING:'

        def broadcast_ping(payload):
            """Broadcast `payload` as a PING on `node_port`, then close the socket."""
            bc_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                bc_sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
                bc_sock.sendto(ping + serialize(payload),
                               ('<broadcast>', node_port))
            finally:
                bc_sock.close()

        def in_local_net(addr_str):
            """Return True if `addr_str` falls inside the configured network."""
            return (struct.unpack('>L', socket.inet_aton(addr_str))[0] &
                    netmask) == netaddr

        netaddr = None
        if not netmask:
            try:
                ip_addr, bits = ip_addr.split('/')
                socket.inet_aton(ip_addr)
                netmask = (0xffffffff << (32 - int(bits))) & 0xffffffff
                netaddr = (struct.unpack(
                    '>L', socket.inet_aton(ip_addr))[0]) & netmask
            except Exception:
                # no usable '/bits' suffix: fall back to a single-host mask
                netmask = '255.255.255.255'
        if ip_addr:
            socket.inet_aton(ip_addr)
        else:
            ip_addr = socket.gethostbyname(socket.gethostname())
        if not netaddr and netmask:
            try:
                if isinstance(netmask, str):
                    netmask = struct.unpack('>L', socket.inet_aton(netmask))[0]
                elif not isinstance(netmask, int):
                    raise ValueError('netmask must be a string or an integer')
                if netmask <= 0:
                    raise ValueError('netmask must be positive')
                netaddr = (struct.unpack(
                    '>L', socket.inet_aton(ip_addr))[0]) & netmask
            except Exception:
                logger.warning('Invalid netmask')

        # If either value cannot be packed as an IPv4 address, disable the
        # "own network" filter altogether.
        try:
            socket.inet_ntoa(struct.pack('>L', netaddr))
            socket.inet_ntoa(struct.pack('>L', netmask))
        except Exception:
            netaddr = netmask = None

        scheduler_ip_addrs = [
            addr for addr in [_node_ipaddr(scheduler_node)] if addr
        ]
        if scheduler_ip_addrs and scheduler_port:
            broadcast_ping({
                'ip_addrs': scheduler_ip_addrs,
                'port': scheduler_port,
                'version': _dispy_version,
                'sign': None
            })

        node_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sched_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            node_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            node_sock.bind(('', node_port))
            sched_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sched_sock.bind(('', scheduler_port))
            logger.info('Listening on %s:%s/%s', ip_addr, node_port,
                        scheduler_port)
            while True:
                ready = select.select([node_sock, sched_sock], [], [])[0]
                for sock in ready:
                    if sock == node_sock:
                        msg, addr = node_sock.recvfrom(1024)
                        if not msg.startswith(ping):
                            logger.debug('Ignoring message "%s" from %s',
                                         msg[:min(len(msg), 5)], addr[0])
                            continue
                        if netaddr and in_local_net(addr[0]):
                            logger.debug('Ignoring ping back (from %s)',
                                         addr[0])
                            continue
                        logger.debug('Ping message from %s (%s)', addr[0],
                                     addr[1])
                        try:
                            info = unserialize(msg[len(ping):])
                            if info['version'] != _dispy_version:
                                logger.warning(
                                    'Ignoring %s due to version mismatch: %s / %s',
                                    info['ip_addrs'], info['version'],
                                    _dispy_version)
                                continue
                            scheduler_ip_addrs = info['ip_addrs'] + [addr[0]]
                            scheduler_port = info['port']
                        except Exception:
                            logger.debug('Ignoring ping message from %s (%s)',
                                         addr[0], addr[1])
                            logger.debug(traceback.format_exc())
                            continue
                        logger.debug('relaying ping from %s / %s',
                                     info['ip_addrs'], addr[0])
                        broadcast_ping(info)
                    else:
                        # select() was given only two sockets, so this is sched_sock
                        msg, addr = sched_sock.recvfrom(1024)
                        if not (msg.startswith(ping) and scheduler_ip_addrs
                                and scheduler_port):
                            continue
                        try:
                            info = unserialize(msg[len(ping):])
                            if netaddr and info.get('scheduler_ip_addr', None) and \
                               in_local_net(info['scheduler_ip_addr']):
                                logger.debug('Ignoring ping back (from %s)' %
                                             addr[0])
                                continue
                            if info['version'] != _dispy_version:
                                raise ValueError('version mismatch')
                            # assert isinstance(info['cpus'], int)
                            reply = {
                                'ip_addrs': scheduler_ip_addrs,
                                'port': scheduler_port,
                                'version': _dispy_version
                            }
                            relay_sock = socket.socket(socket.AF_INET,
                                                       socket.SOCK_DGRAM)
                            try:
                                relay_sock.sendto(
                                    ping + serialize(reply),
                                    (info['ip_addr'], info['port']))
                            finally:
                                relay_sock.close()
                        except Exception:
                            logger.debug(traceback.format_exc())
                            # raise
                            logger.debug('Ignoring ping message from %s (%s)',
                                         addr[0], addr[1])
        finally:
            # release the ports when the loop is interrupted or setup fails
            node_sock.close()
            sched_sock.close()
Ejemplo n.º 17
0
    def __init__(self,
                 ip_addrs=None,
                 node_port=51348,
                 relay_port=0,
                 scheduler_nodes=None,
                 scheduler_port=51347,
                 ipv4_udp_multicast=False,
                 secret='',
                 certfile=None,
                 keyfile=None):
        """Record the relay's configuration and start its tasks.

        ip_addrs: addresses/hosts to serve on; each is resolved with
            dispy.host_addrinfo and invalid ones are skipped with a warning.
            When empty or None, a single lookup with host=None is made.
        node_port: stored as `self.node_port`.
        relay_port: stored as `self.relay_port`; defaults to `node_port`
            when falsy.
        scheduler_nodes: scheduler hosts; for each one that resolves, a
            verify_broadcast task is started per UDP bind address.
        scheduler_port: stored and announced in those messages.
        ipv4_udp_multicast: when true, IPv4 addresses use
            dispy.IPV4_MULTICAST_GROUP instead of their broadcast address.
        secret: stored as `self.secret`.
        certfile / keyfile: stored as absolute paths, or None when not given.
        """
        # None replaces the former mutable [] defaults; both mean "not given".
        addrinfos = []
        for ip_addr in (ip_addrs or [None]):
            addrinfo = dispy.host_addrinfo(host=ip_addr)
            if not addrinfo:
                logger.warning('Ignoring invalid ip_addr %s', ip_addr)
                continue
            addrinfos.append(addrinfo)

        self.node_port = node_port
        self.relay_port = relay_port or node_port
        self.ipv4_udp_multicast = bool(ipv4_udp_multicast)
        self.ip_addrs = set()
        self.scheduler_ip_addr = None
        self.scheduler_port = scheduler_port
        self.secret = secret
        self.certfile = os.path.abspath(certfile) if certfile else None
        self.keyfile = os.path.abspath(keyfile) if keyfile else None

        # One UDP listener per distinct bind address; the address to bind to
        # depends on the platform.
        udp_addrinfos = {}
        for addrinfo in addrinfos:
            self.ip_addrs.add(addrinfo.ip)
            if addrinfo.family == socket.AF_INET and self.ipv4_udp_multicast:
                addrinfo.broadcast = dispy.IPV4_MULTICAST_GROUP
            Task(self.relay_tcp_proc, addrinfo)
            if os.name == 'nt':
                bind_addr = addrinfo.ip
            elif (sys.platform == 'darwin'
                  and addrinfo.family == socket.AF_INET
                  and not self.ipv4_udp_multicast):
                bind_addr = ''
            else:
                bind_addr = addrinfo.broadcast
            udp_addrinfos[bind_addr] = addrinfo

        # Keep only the scheduler hosts that resolve to an address.
        scheduler_ip_addrs = []
        for addr in (scheduler_nodes or []):
            addr = dispy._node_ipaddr(addr)
            if addr:
                scheduler_ip_addrs.append(addr)

        for bind_addr, addrinfo in udp_addrinfos.items():
            Task(self.relay_udp_proc, bind_addr, addrinfo)
            Task(self.sched_udp_proc, bind_addr, addrinfo)
            for addr in scheduler_ip_addrs:
                msg = {
                    'version': __version__,
                    'ip_addrs': [addr],
                    'port': self.scheduler_port,
                    'sign': None
                }
                Task(self.verify_broadcast, addrinfo, msg)

        logger.info('version %s started', dispy._dispy_version)
Ejemplo n.º 18
0
    def relay_pings(self, ip_addr='', netmask=None, node_port=51348,
                    scheduler_node=None, scheduler_port=51347):
        """Relay scheduler PING and node PONG datagrams; never returns.

        Listens for UDP 'PING:' messages on node_port and re-broadcasts
        them, and listens for 'PONG:' messages on scheduler_port and
        answers each with a 'PING:' sent directly to the host named in
        the pong.

        Arguments:
          ip_addr -- local IPv4 address, optionally as 'a.b.c.d/bits'
            when netmask is not given. If empty, the address of the
            local host name is used.
          netmask -- dotted-quad string or positive integer. Together
            with ip_addr it defines the network whose pings are treated
            as this relay's own and ignored. An invalid value is logged
            and disables that filter.
          node_port -- UDP port pings are received on and relayed to.
          scheduler_node -- optional scheduler host; when it resolves
            and scheduler_port is set, an initial ping is broadcast.
          scheduler_port -- UDP port pongs are received on. It is also
            sent as 'scheduler_port' in relayed pings until a received
            ping supplies a different value.

        Raises socket.error if ip_addr is given but is not a valid IPv4
        address, or if a socket cannot be created or bound.
        """
        netaddr = None
        if not netmask:
            # Accept CIDR notation; anything else falls back to a
            # host-only mask.
            try:
                ip_addr, bits = ip_addr.split('/')
                socket.inet_aton(ip_addr)
                netmask = (0xffffffff << (32 - int(bits))) & 0xffffffff
                netaddr = (struct.unpack('>L', socket.inet_aton(ip_addr))[0]) & netmask
            except Exception:
                netmask = '255.255.255.255'
        if ip_addr:
            socket.inet_aton(ip_addr)
        else:
            ip_addr = socket.gethostbyname(socket.gethostname())
        if not netaddr and netmask:
            try:
                if isinstance(netmask, str):
                    netmask = struct.unpack('>L', socket.inet_aton(netmask))[0]
                elif not isinstance(netmask, int):
                    # Explicit raise instead of assert so the check
                    # survives "python -O".
                    raise TypeError('netmask must be a string or an integer')
                if netmask <= 0:
                    raise ValueError('netmask must be positive')
                netaddr = (struct.unpack('>L', socket.inet_aton(ip_addr))[0]) & netmask
            except Exception:
                logger.warning('Invalid netmask')

        # Both values must be packable as 32-bit addresses; otherwise
        # the "own ping" filter below is switched off.
        try:
            socket.inet_ntoa(struct.pack('>L', netaddr))
            socket.inet_ntoa(struct.pack('>L', netmask))
        except Exception:
            netaddr = netmask = None

        scheduler_version = _dispy_version

        bc_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        bc_sock.bind(('', 0))
        bc_sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

        scheduler_ip_addr = _node_ipaddr(scheduler_node)
        if scheduler_ip_addr and scheduler_port:
            relay_request = serialize({'scheduler_ip_addr':scheduler_ip_addr,
                                       'scheduler_port':scheduler_port,
                                       'version':scheduler_version})
            bc_sock.sendto('PING:%s' % relay_request, ('<broadcast>', node_port))

        ping_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        ping_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        ping_sock.bind(('', node_port))
        pong_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        pong_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        pong_sock.bind(('', scheduler_port))
        logger.info('Listening on %s:%s', ip_addr, node_port)
        last_ping = 0
        # NOTE(review): payloads below come from the network and are fed
        # to unserialize(); if that helper is pickle-based this is unsafe
        # on untrusted networks -- confirm.
        while True:
            ready = select.select([ping_sock, pong_sock], [], [])[0]
            for sock in ready:
                if sock == ping_sock:
                    msg, addr = ping_sock.recvfrom(1024)
                    if not msg.startswith('PING:'):
                        # Log only a short prefix of the unknown datagram.
                        logger.debug('Ignoring message "%s" from %s',
                                     msg[:5], addr[0])
                        continue
                    if netaddr and (struct.unpack('>L', socket.inet_aton(addr[0]))[0] & netmask) == netaddr:
                        logger.debug('Ignoring own ping (from %s)', addr[0])
                        continue
                    # Throttle: a ping arriving within 10 seconds of the
                    # previous one is delayed before being processed.
                    if (time.time() - last_ping) < 10:
                        logger.warning('Ignoring ping (from %s) for 10 more seconds', addr[0])
                        time.sleep(10)
                    last_ping = time.time()
                    logger.debug('Ping message from %s (%s)', addr[0], addr[1])
                    # Parse into temporaries so a malformed ping cannot
                    # leave the relay with half-updated scheduler state.
                    try:
                        data = unserialize(msg[len('PING:'):])
                        new_ip_addr = data['scheduler_ip_addr']
                        new_port = data['scheduler_port']
                        new_version = data['version']
                        if not isinstance(new_ip_addr, str) or not isinstance(new_port, int):
                            raise ValueError('malformed ping payload')
                    except Exception:
                        logger.debug('Ignoring ping message from %s (%s)', addr[0], addr[1])
                        continue
                    scheduler_ip_addr = new_ip_addr
                    scheduler_port = new_port
                    scheduler_version = new_version
                    relay_request = serialize({'scheduler_ip_addr':scheduler_ip_addr,
                                               'scheduler_port':scheduler_port,
                                               'version':scheduler_version})
                    bc_sock.sendto('PING:%s' % relay_request, ('<broadcast>', node_port))
                else:
                    # select() was given only two sockets, so this is
                    # pong_sock.
                    msg, addr = pong_sock.recvfrom(1024)
                    if not msg.startswith('PONG:'):
                        logger.debug('Ignoring pong message "%s" from %s',
                                     msg[:5], addr[0])
                        continue
                    # NOTE: unlike pings, pongs are not filtered by
                    # source network here.
                    if not (scheduler_ip_addr and scheduler_port):
                        logger.debug('Ignoring pong message from %s', str(addr))
                        continue
                    logger.debug('Pong message from %s (%s)', addr[0], addr[1])
                    try:
                        pong = unserialize(msg[len('PONG:'):])
                        if not (isinstance(pong['host'], str) and
                                isinstance(pong['port'], int) and
                                isinstance(pong['cpus'], int)):
                            raise ValueError('malformed pong payload')
                        relay_request = serialize({'scheduler_ip_addr':scheduler_ip_addr,
                                                   'scheduler_port':scheduler_port,
                                                   'version':scheduler_version})
                        relay_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                        try:
                            relay_sock.sendto('PING:%s' % relay_request,
                                              (pong['host'], node_port))
                        finally:
                            # Close even when sendto fails so the
                            # descriptor is not leaked.
                            relay_sock.close()
                    except Exception:
                        logger.debug('Ignoring pong message from %s (%s)', addr[0], addr[1])