Пример #1
0
class BenchmarkWorker:
    """Fire a fixed number of RADIUS auth requests at a server at a steady rate.

    Requests are spread round-robin over ``concurrency`` RadAuthClient
    instances and scheduled on the Twisted reactor ``1/rate`` seconds apart.
    After ``timeout`` seconds ``on_timeout`` stops the reactor. Status lines
    are pushed over ZeroMQ to ``ipc:///tmp/toughbt-message``.
    """

    def __init__(self, server, port, secret, requests, concurrency,
                 username, password, verb=False, timeout=600, rate=1000):
        """
        :param server: RADIUS server address passed to each client.
        :param port: RADIUS auth port passed to each client.
        :param secret: shared secret (converted with ``str``).
        :param requests: total number of auth requests to schedule.
        :param concurrency: number of clients to cycle through; must be >= 1.
        :param username: value sent as ``User-Name``.
        :param password: value sent as ``User-Password``.
        :param verb: passed to each client as ``debug``.
        :param timeout: delay in seconds before ``on_timeout`` is called.
        :param rate: requests per second; must be > 0.
        :raises ValueError: if ``rate`` <= 0 or ``concurrency`` < 1.
        """
        # Validate before any side effects. rate=0 used to raise
        # ZeroDivisionError. concurrency=0 produced an empty cycle, which made
        # every scheduled reactor callback fail.
        if rate <= 0:
            raise ValueError("rate must be > 0, got %r" % (rate,))
        if concurrency < 1:
            raise ValueError("concurrency must be >= 1, got %r" % (concurrency,))

        logname = "/tmp/trbctl-worker-{}.log".format(os.environ.get("LOGID", 0))
        # The file object is handed to Twisted's logger and stays open for
        # the rest of the process.
        log.startLogging(open(logname, 'w'))
        self.timeout = timeout
        self.pusher = ZmqPushConnection(ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/toughbt-message'))
        self.pusher.push("write worker %s log into %s" % (os.getpid(), logname))
        log.msg("init BenchmarkWorker pusher : %s " % repr(self.pusher))

        # Build the pool of clients that requests are cycled through.
        raddict = dictionary.Dictionary(os.path.join(os.path.dirname(__file__), "dictionary"))

        def new_cli():
            return RadAuthClient(str(secret), raddict, server, port=port, debug=verb, stat_push=self.pusher)

        clis = itertools.cycle([new_cli() for _ in range(concurrency)])

        def send():
            return next(clis).sendAuth(**{'User-Name': username, 'User-Password': password})

        send_rate = 1.0 / rate
        # Derive each delay as a multiple rather than a running sum, so that
        # floating-point rounding does not accumulate over many requests.
        for i in xrange(requests):
            reactor.callLater((i + 1) * send_rate, send)

        reactor.callLater(self.timeout, self.on_timeout)

    def on_timeout(self):
        """Report the timeout over ZeroMQ and stop the reactor."""
        self.pusher.push("logger: BenchmarkWorker timeout, running times: %s" % self.timeout)
        reactor.stop()
Пример #2
0
class ZeroMQDelegatorService(Service):
    """
    This is an outbound PUSH connection to the web API ZeroMQRepeaterService
    that allows a worker to delegate any sub-links it finds (instead of taking
    a detour to crawl them on their own).
    """

    def __init__(self):
        # Created in startService; None until the service has been started
        # and again after it has been stopped.
        self.conn = None

    def startService(self):
        """Connect the PUSH socket to ``ZMQ_REPEATER`` and mark the service running."""
        # Chain up so Twisted's ``running`` flag is maintained.
        Service.startService(self)
        factory = ZmqFactory()
        log.msg("Delegator connecting to repeater: %s" % ZMQ_REPEATER)
        endpoint = ZmqEndpoint('connect', ZMQ_REPEATER)
        self.conn = ZmqPushConnection(factory, endpoint)

    def stopService(self):
        """Close the PUSH socket (if open) and mark the service stopped."""
        if self.conn is not None:
            self.conn.shutdown()
            self.conn = None
        return Service.stopService(self)

    def send_message(self, message):
        """
        Matches the signature of ZeroMQBroadcastService so we can use them
        interchangeably in the job queue code.

        :param str message: A JSON crawler message.
        """

        log.msg("Delegating job: %s" % message)
        self.conn.push(message)
Пример #3
0
class TwistedZmqClient(object):
    """Thin ZeroMQ PUSH client that forwards raw messages to a peer."""

    def __init__(self, service):
        """Connect to ``tcp://<service.host>:<service.port>``.

        :param service: any object exposing ``host`` and ``port`` attributes.
        """
        address = 'tcp://%s:%s' % (service.host, service.port)
        self.conn = ZmqPushConnection(ZmqFactory(), ZmqEndpoint('connect', address))

    def send(self, msg):
        """Push ``msg`` unchanged over the connection."""
        self.conn.push(msg)
Пример #4
0
class TwistedZmqClient(object):
    """ZeroMQ PUSH client that sends JSON-encoded client/level updates."""

    def __init__(self, service):
        """Connect to ``tcp://<service.host>:<service.port>``.

        :param service: any object exposing ``host`` and ``port`` attributes.
        """
        address = 'tcp://%s:%s' % (service.host, service.port)
        self.conn = ZmqPushConnection(ZmqFactory(), ZmqEndpoint('connect', address))

    def send(self, clientid, levellist):
        """Push ``{"clientid": ..., "levellist": ...}`` as a JSON string."""
        payload = {'clientid': clientid, 'levellist': levellist}
        self.conn.push(json.dumps(payload))
Пример #5
0
class ZeroMQBroadcastService(Service):
    """
    This is used by the HTTP API to hand crawl jobs off to the crawler pool.
    """

    # Endpoint the PUSH socket binds to when none is given.
    DEFAULT_BIND_POINT = 'tcp://0.0.0.0:8050'

    def __init__(self, bind_point=DEFAULT_BIND_POINT):
        """
        :param str bind_point: ZeroMQ endpoint to bind; defaults to
            ``tcp://0.0.0.0:8050`` (all interfaces, port 8050).
        """
        # Created in startService; None until the service has been started
        # and again after it has been stopped.
        self.conn = None
        self.bind_point = bind_point

    def startService(self):
        """Bind the PUSH socket on ``self.bind_point`` and mark the service running."""
        # Chain up so Twisted's ``running`` flag is maintained.
        Service.startService(self)
        factory = ZmqFactory()
        log.msg("Broadcaster binding on: %s" % self.bind_point)
        endpoint = ZmqEndpoint('bind', self.bind_point)
        self.conn = ZmqPushConnection(factory, endpoint)

    def stopService(self):
        """Close the PUSH socket (if open) and mark the service stopped."""
        if self.conn is not None:
            self.conn.shutdown()
            self.conn = None
        return Service.stopService(self)

    def send_message(self, message):
        """Push a crawl announcement to whichever crawler pulls next.

        :param str message: A JSON crawler message.
        """
        log.msg("Sent crawl announcement: %s" % message)
        self.conn.push(message)
Пример #6
0
class RADIUSAuthWorker(object):
    """Authentication worker.

    Pulls msgpack-encoded ``[datagram, host, port]`` triples from
    ``ipc:///tmp/radiusd-auth-message``, authorizes the Access-Request and
    pushes ``[reply_packet, host, port]`` to ``ipc:///tmp/radiusd-auth-result``.
    Per-packet counters are pushed to ``ipc:///tmp/radiusd-stat-task``.
    """

    def __init__(self, config, dbengine):
        """
        :param config: application config; ``config.system.secret`` keys the
            AES cipher and ``config.system.debug`` enables packet dumps.
        :param dbengine: database engine; when falsy one is created with
            ``get_engine(config)``.
        """
        self.config = config
        self.dict = dictionary.Dictionary(
            os.path.join(os.path.dirname(toughradius.__file__), 'dictionarys/dictionary'))
        self.db_engine = dbengine or get_engine(config)
        self.aes = utils.AESCipher(key=self.config.system.secret)
        self.mcache = mcache.Mcache()
        self.pusher = ZmqPushConnection(ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-auth-result'))
        self.stat_pusher = ZmqPushConnection(ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-stat-task'))
        self.puller = ZmqPullConnection(ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-auth-message'))
        # Every pulled message is handled by self.process.
        self.puller.onPull = self.process
        logger.info("init auth worker pusher : %s " % (self.pusher))
        logger.info("init auth worker puller : %s " % (self.puller))
        logger.info("init auth stat pusher : %s " % (self.stat_pusher))

    def find_nas(self, ip_addr):
        """Return the ``TrBas`` row whose ``ip_addr`` matches.

        The lookup goes through ``self.mcache.aget`` under
        ``bas_cache_key(ip_addr)`` with ``expire=600``. The database is only
        queried by ``fetch_result``, which yields None when no row matches.
        """
        def fetch_result():
            table = models.TrBas.__table__
            with self.db_engine.begin() as conn:
                return conn.execute(table.select().where(table.c.ip_addr == ip_addr)).first()
        return self.mcache.aget(bas_cache_key(ip_addr), fetch_result, expire=600)

    def do_stat(self, code):
        """Push a one-element counter list for RADIUS packet ``code``.

        Unknown codes (including the ``0`` used for errors) count as
        ``auth_drop``. Statistics are best-effort: a failure is logged and
        never raised.
        """
        try:
            if code == packet.AccessRequest:
                stat_msg = ['auth_req']
            elif code == packet.AccessAccept:
                stat_msg = ['auth_accept']
            elif code == packet.AccessReject:
                stat_msg = ['auth_reject']
            else:
                stat_msg = ['auth_drop']
            self.stat_pusher.push(msgpack.packb(stat_msg))
        except Exception as err:
            # Was a bare ``except: pass``, which also hid SystemExit and
            # KeyboardInterrupt.
            logger.error("push auth stat failed: %s" % utils.safeunicode(err))

    def process(self, message):
        """Authorize one pulled message and push the reply, if any.

        :param message: ZeroMQ multipart message whose first frame is a
            msgpack-encoded ``[datagram, host, port]``.
        """
        datagram, host, port = msgpack.unpackb(message[0])
        reply = self.processAuth(datagram, host, port)
        # processAuth returns None for dropped packets. Test identity rather
        # than truthiness, because a packet object may define its own truth
        # value (e.g. if it is dict-like) — TODO confirm.
        if reply is None:
            return
        self.do_stat(reply.code)
        logger.info("[Radiusd] :: Send radius response: %s" % repr(reply))
        if self.config.system.debug:
            logger.debug(reply.format_str())
        self.pusher.push(msgpack.packb([reply.ReplyPacket(), host, port]))

    def createAuthPacket(self, **kwargs):
        """Build an ``AuthMessage`` from ``kwargs`` and run the MAC/VLAN parsers.

        ``vendor_id`` (default 0) is removed from ``kwargs`` and set on the
        message instead of being passed to its constructor.
        """
        vendor_id = kwargs.pop('vendor_id', 0)
        auth_message = message.AuthMessage(**kwargs)
        auth_message.vendor_id = vendor_id
        auth_message = mac_parse.process(auth_message)
        auth_message = vlan_parse.process(auth_message)
        return auth_message

    def processAuth(self, datagram, host, port):
        """Authorize one raw Access-Request and build the reply.

        :param datagram: raw RADIUS packet.
        :param host: source address, used to look up the NAS and its secret.
        :param port: source port (used in error logs).
        :returns: an Access-Accept or Access-Reject reply, or None when the
            packet was dropped (the error is logged and counted as a drop).
        """
        try:
            bas = self.find_nas(host)
            if not bas:
                raise PacketError('[Radiusd] :: Dropping packet from unknown host %s' % host)

            secret, vendor_id = bas['bas_secret'], bas['vendor_id']
            req = self.createAuthPacket(packet=datagram,
                dict=self.dict, secret=six.b(str(secret)), vendor_id=vendor_id)

            self.do_stat(req.code)

            logger.info("[Radiusd] :: Received radius request: %s" % (repr(req)))
            if self.config.system.debug:
                logger.debug(req.format_str())

            if req.code != packet.AccessRequest:
                raise PacketError('non-AccessRequest packet on authentication socket')

            reply = req.CreateReply()
            reply.vendor_id = req.vendor_id

            aaa_request = dict(
                account_number=req.get_user_name(),
                domain=req.get_domain(),
                macaddr=req.client_mac,
                nasaddr=req.get_nas_addr(),
                vlanid1=req.vlanid1,
                vlanid2=req.vlanid2
            )

            auth_resp = RadiusAuth(self.db_engine, self.mcache, self.aes, aaa_request).authorize()

            # A positive code means authorize() refused the account; its
            # message is returned before any password check.
            if auth_resp['code'] > 0:
                reply['Reply-Message'] = auth_resp['msg']
                reply.code = packet.AccessReject
                return reply

            # bypass == 0 skips password verification entirely.
            if 'bypass' in auth_resp and int(auth_resp['bypass']) == 0:
                is_pwd_ok = True
            else:
                is_pwd_ok = req.is_valid_pwd(auth_resp.get('passwd'))

            if not is_pwd_ok:
                reply['Reply-Message'] = "password not match"
                reply.code = packet.AccessReject
                return reply

            if u"input_rate" in auth_resp and u"output_rate" in auth_resp:
                reply = rate_process.process(
                    reply, input_rate=auth_resp['input_rate'], output_rate=auth_resp['output_rate'])

            attrs = auth_resp.get("attrs") or {}
            for attr_name in attrs:
                try:
                    # todo: May have a type matching problem
                    reply.AddAttribute(utils.safestr(attr_name), attrs[attr_name])
                except Exception as err:
                    # Format the exception itself. ``err.message`` is deprecated
                    # in py2 and absent in py3, where it would escape to the
                    # outer handler and drop the whole packet.
                    errstr = "RadiusError:current radius cannot support attribute {0},{1}".format(
                        attr_name, utils.safestr(err))
                    logger.error(errstr)

            for attr, attr_val in req.resp_attrs.iteritems():
                reply[attr] = attr_val

            reply['Reply-Message'] = 'success!'
            reply.code = packet.AccessAccept
            if not req.VerifyReply(reply):
                raise PacketError('VerifyReply error')
            return reply
        except Exception as err:
            self.do_stat(0)
            errstr = 'RadiusError:Dropping invalid auth packet from {0} {1},{2}'.format(
                host, port, utils.safeunicode(err))
            logger.error(errstr)
            import traceback
            traceback.print_exc()
Пример #7
0
class RADIUSAcctWorker(object):
    """Accounting worker.

    Pulls msgpack-encoded ``[datagram, host, port]`` triples from
    ``ipc:///tmp/radiusd-acct-message``. For each valid Accounting-Request it
    immediately pushes the response to ``ipc:///tmp/radiusd-acct-result``,
    then schedules the billing handler for the request's status type on the
    reactor. Counters are pushed to ``ipc:///tmp/radiusd-stat-task``.
    """

    def __init__(self, config, dbengine):
        """
        :param config: application config; ``config.system.debug`` enables
            packet dumps.
        :param dbengine: database engine; when falsy one is created with
            ``get_engine(config)``.
        """
        self.config = config
        self.dict = dictionary.Dictionary(
            os.path.join(os.path.dirname(toughradius.__file__), 'dictionarys/dictionary'))
        self.db_engine = dbengine or get_engine(config)
        self.mcache = mcache.Mcache()
        self.pusher = ZmqPushConnection(ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-acct-result'))
        self.stat_pusher = ZmqPushConnection(ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-stat-task'))
        self.puller = ZmqPullConnection(ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-acct-message'))
        # Every pulled message is handled by self.process.
        self.puller.onPull = self.process
        logger.info("init acct worker pusher : %s " % (self.pusher))
        logger.info("init acct worker puller : %s " % (self.puller))
        logger.info("init acct stat pusher : %s " % (self.stat_pusher))
        # Acct-Status-Type -> handler class; instances expose ``acctounting``.
        self.acct_class = {
            STATUS_TYPE_START: RadiusAcctStart,
            STATUS_TYPE_STOP: RadiusAcctStop,
            STATUS_TYPE_UPDATE: RadiusAcctUpdate,
            STATUS_TYPE_ACCT_ON: RadiusAcctOnoff,
            STATUS_TYPE_ACCT_OFF: RadiusAcctOnoff
        }

    def find_nas(self, ip_addr):
        """Return the ``TrBas`` row whose ``ip_addr`` matches.

        The lookup goes through ``self.mcache.aget`` under
        ``bas_cache_key(ip_addr)`` with ``expire=600``. The database is only
        queried by ``fetch_result``, which yields None when no row matches.
        """
        def fetch_result():
            table = models.TrBas.__table__
            with self.db_engine.begin() as conn:
                return conn.execute(table.select().where(table.c.ip_addr == ip_addr)).first()
        return self.mcache.aget(bas_cache_key(ip_addr), fetch_result, expire=600)

    def do_stat(self, code, status_type=0):
        """Push accounting counters for one packet.

        :param code: RADIUS packet code. Codes other than 4 and 5 count as
            ``acct_drop``.
        :param status_type: Acct-Status-Type; 1/2/3/7/8 add
            start/stop/update/on/off counters.

        Statistics are best-effort: a failure is logged and never raised.
        """
        try:
            stat_msg = ['acct_drop']
            if code in (4, 5):
                stat_msg = []
                if code == packet.AccountingRequest:
                    stat_msg.append('acct_req')
                elif code == packet.AccountingResponse:
                    stat_msg.append('acct_resp')

                if status_type == 1:
                    stat_msg.append('acct_start')
                elif status_type == 2:
                    stat_msg.append('acct_stop')
                elif status_type == 3:
                    stat_msg.append('acct_update')
                elif status_type == 7:
                    stat_msg.append('acct_on')
                elif status_type == 8:
                    stat_msg.append('acct_off')
            self.stat_pusher.push(msgpack.packb(stat_msg))
        except Exception as err:
            # Was a bare ``except: pass``, which also hid SystemExit and
            # KeyboardInterrupt.
            logger.error("push acct stat failed: %s" % utils.safeunicode(err))

    def process(self, message):
        """Handle one pulled message (first frame: msgpack ``[datagram, host, port]``)."""
        datagram, host, port = msgpack.unpackb(message[0])
        self.processAcct(datagram, host, port)

    def createAcctPacket(self, **kwargs):
        """Build an ``AcctMessage`` from ``kwargs`` and run the MAC/VLAN parsers.

        ``vendor_id`` (default 0) is removed from ``kwargs`` and set on the
        message instead of being passed to its constructor.
        """
        vendor_id = kwargs.pop('vendor_id', 0)
        acct_message = message.AcctMessage(**kwargs)
        acct_message.vendor_id = vendor_id
        acct_message = mac_parse.process(acct_message)
        acct_message = vlan_parse.process(acct_message)
        return acct_message

    def processAcct(self, datagram, host, port):
        """Validate one accounting packet, reply to it, then schedule billing.

        Any failure is logged, counted as ``acct_drop`` and swallowed.
        """
        try:
            bas = self.find_nas(host)
            if not bas:
                raise PacketError('[Radiusd] :: Dropping packet from unknown host %s' % host)

            secret, vendor_id = bas['bas_secret'], bas['vendor_id']
            req = self.createAcctPacket(packet=datagram,
                dict=self.dict, secret=six.b(str(secret)), vendor_id=vendor_id)

            status_type = req.get_acct_status_type()
            self.do_stat(req.code, status_type)

            logger.info("[Radiusd] :: Received radius request: %s" % (repr(req)))
            if self.config.system.debug:
                logger.debug(req.format_str())

            if req.code != packet.AccountingRequest:
                raise PacketError('non-AccountingRequest packet on accounting socket')

            if not req.VerifyAcctRequest():
                raise PacketError('VerifyAcctRequest error')

            # Acknowledge first; billing runs afterwards on the reactor.
            reply = req.CreateReply()
            self.pusher.push(msgpack.packb([reply.ReplyPacket(), host, port]))
            self.do_stat(reply.code)
            logger.info("[Radiusd] :: Send radius response: %s" % repr(reply))
            if self.config.system.debug:
                logger.debug(reply.format_str())

            if status_type in self.acct_class:
                acct_func = self.acct_class[status_type](
                        self.db_engine, self.mcache, None, req.get_ticket()).acctounting
                reactor.callLater(0.1, acct_func)
            else:
                logger.error('status_type <%s> not support' % status_type)
        except Exception as err:
            self.do_stat(0)
            errstr = 'RadiusError:Dropping invalid acct packet from {0} {1},{2}'.format(
                host, port, utils.safeunicode(err))
            logger.error(errstr)
            import traceback
            traceback.print_exc()
Пример #8
0
class RADIUSAcctWorker(object):
    """Accounting worker.

    Pulls msgpack-encoded ``[datagram, host, port]`` triples from
    ``ipc:///tmp/radiusd-acct-message``. For each valid Accounting-Request it
    immediately pushes the response to ``ipc:///tmp/radiusd-acct-result``,
    then schedules the billing handler for the request's status type on the
    reactor. Counters are pushed to ``ipc:///tmp/radiusd-stat-task``.
    """

    def __init__(self, config, dbengine, radcache=None):
        """
        :param config: application config; ``config.system.debug`` enables
            packet dumps.
        :param dbengine: database engine; when falsy one is created with
            ``get_engine(config)``.
        :param radcache: cache object exposing ``aget``.
            NOTE(review): find_nas calls ``self.mcache.aget``, so radcache
            must not be None despite the default — confirm callers pass it.
        """
        self.config = config
        self.dict = dictionary.Dictionary(
            os.path.join(os.path.dirname(toughradius.__file__),
                         'dictionarys/dictionary'))
        self.db_engine = dbengine or get_engine(config)
        self.mcache = radcache
        self.pusher = ZmqPushConnection(
            ZmqFactory(),
            ZmqEndpoint('connect', 'ipc:///tmp/radiusd-acct-result'))
        self.stat_pusher = ZmqPushConnection(
            ZmqFactory(), ZmqEndpoint('connect',
                                      'ipc:///tmp/radiusd-stat-task'))
        self.puller = ZmqPullConnection(
            ZmqFactory(),
            ZmqEndpoint('connect', 'ipc:///tmp/radiusd-acct-message'))
        # Every pulled message is handled by self.process.
        self.puller.onPull = self.process
        logger.info("init acct worker pusher : %s " % (self.pusher))
        logger.info("init acct worker puller : %s " % (self.puller))
        logger.info("init acct stat pusher : %s " % (self.stat_pusher))
        # Acct-Status-Type -> handler class; instances expose ``acctounting``.
        self.acct_class = {
            STATUS_TYPE_START: RadiusAcctStart,
            STATUS_TYPE_STOP: RadiusAcctStop,
            STATUS_TYPE_UPDATE: RadiusAcctUpdate,
            STATUS_TYPE_ACCT_ON: RadiusAcctOnoff,
            STATUS_TYPE_ACCT_OFF: RadiusAcctOnoff
        }

    def find_nas(self, ip_addr):
        """Return the ``TrBas`` row whose ``ip_addr`` matches.

        The lookup goes through ``self.mcache.aget`` under
        ``bas_cache_key(ip_addr)`` with ``expire=600``. The database is only
        queried by ``fetch_result``, which yields None when no row matches.
        """
        def fetch_result():
            table = models.TrBas.__table__
            with self.db_engine.begin() as conn:
                return conn.execute(
                    table.select().where(table.c.ip_addr == ip_addr)).first()

        return self.mcache.aget(bas_cache_key(ip_addr),
                                fetch_result,
                                expire=600)

    def do_stat(self, code, status_type=0):
        """Push accounting counters for one packet.

        :param code: RADIUS packet code. Codes other than 4 and 5 count as
            ``acct_drop``.
        :param status_type: Acct-Status-Type; 1/2/3/7/8 add
            start/stop/update/on/off counters.

        Statistics are best-effort: a failure is logged and never raised.
        """
        try:
            stat_msg = ['acct_drop']
            if code in (4, 5):
                stat_msg = []
                if code == packet.AccountingRequest:
                    stat_msg.append('acct_req')
                elif code == packet.AccountingResponse:
                    stat_msg.append('acct_resp')

                if status_type == 1:
                    stat_msg.append('acct_start')
                elif status_type == 2:
                    stat_msg.append('acct_stop')
                elif status_type == 3:
                    stat_msg.append('acct_update')
                elif status_type == 7:
                    stat_msg.append('acct_on')
                elif status_type == 8:
                    stat_msg.append('acct_off')
            self.stat_pusher.push(msgpack.packb(stat_msg))
        except Exception as err:
            # Was a bare ``except: pass``, which also hid SystemExit and
            # KeyboardInterrupt.
            logger.error("push acct stat failed: %s" % utils.safeunicode(err))

    def process(self, message):
        """Handle one pulled message (first frame: msgpack ``[datagram, host, port]``)."""
        datagram, host, port = msgpack.unpackb(message[0])
        self.processAcct(datagram, host, port)

    def createAcctPacket(self, **kwargs):
        """Build an ``AcctMessage`` from ``kwargs`` and run the MAC/VLAN parsers.

        ``vendor_id`` (default 0) is removed from ``kwargs`` and set on the
        message instead of being passed to its constructor.
        """
        vendor_id = kwargs.pop('vendor_id', 0)
        acct_message = message.AcctMessage(**kwargs)
        acct_message.vendor_id = vendor_id
        acct_message = mac_parse.process(acct_message)
        acct_message = vlan_parse.process(acct_message)
        return acct_message

    def processAcct(self, datagram, host, port):
        """Validate one accounting packet, reply to it, then schedule billing.

        Any failure is logged, counted as ``acct_drop`` and swallowed.
        """
        try:
            bas = self.find_nas(host)
            if not bas:
                raise PacketError(
                    '[Radiusd] :: Dropping packet from unknown host %s' % host)

            secret, vendor_id = bas['bas_secret'], bas['vendor_id']
            req = self.createAcctPacket(packet=datagram,
                                        dict=self.dict,
                                        secret=six.b(str(secret)),
                                        vendor_id=vendor_id)

            status_type = req.get_acct_status_type()
            self.do_stat(req.code, status_type)

            logger.info("[Radiusd] :: Received radius request: %s" %
                        (repr(req)))
            if self.config.system.debug:
                logger.debug(req.format_str())

            if req.code != packet.AccountingRequest:
                raise PacketError(
                    'non-AccountingRequest packet on accounting socket')

            if not req.VerifyAcctRequest():
                raise PacketError('VerifyAcctRequest error')

            # Acknowledge first; billing runs afterwards on the reactor.
            reply = req.CreateReply()
            self.pusher.push(msgpack.packb([reply.ReplyPacket(), host, port]))
            self.do_stat(reply.code)
            logger.info("[Radiusd] :: Send radius response: %s" % repr(reply))
            if self.config.system.debug:
                logger.debug(reply.format_str())

            if status_type in self.acct_class:
                acct_func = self.acct_class[status_type](
                    self.db_engine, self.mcache, None,
                    req.get_ticket()).acctounting
                reactor.callLater(0.1, acct_func)
            else:
                logger.error('status_type <%s> not support' % status_type)
        except Exception as err:
            self.do_stat(0)
            errstr = 'RadiusError:Dropping invalid acct packet from {0} {1},{2}'.format(
                host, port, utils.safeunicode(err))
            logger.error(errstr)
            import traceback
            traceback.print_exc()
Пример #9
0
class RADIUSAuthWorker(protocol.DatagramProtocol):
    """Authentication worker.

    Pulls msgpack-encoded ``[datagram, host, port]`` triples from
    ``ipc:///tmp/radiusd-auth-message``, authorizes the Access-Request and
    pushes ``[reply_packet, host, port]`` to ``ipc:///tmp/radiusd-auth-result``.
    Per-packet counters are pushed to ``ipc:///tmp/radiusd-stat-task``.
    """

    def __init__(self, config, dbengine, radcache=None):
        """
        :param config: application config; ``config.system.secret`` keys the
            AES cipher and ``config.system.debug`` enables packet dumps.
        :param dbengine: database engine; when falsy one is created with
            ``get_engine(config)``.
        :param radcache: cache object exposing ``aget``.
            NOTE(review): find_nas calls ``self.mcache.aget``, so radcache
            must not be None despite the default — confirm callers pass it.
        """
        self.config = config
        self.dict = dictionary.Dictionary(
            os.path.join(os.path.dirname(toughradius.__file__),
                         'dictionarys/dictionary'))
        self.db_engine = dbengine or get_engine(config)
        self.aes = utils.AESCipher(key=self.config.system.secret)
        self.mcache = radcache
        self.pusher = ZmqPushConnection(
            ZmqFactory(),
            ZmqEndpoint('connect', 'ipc:///tmp/radiusd-auth-result'))
        self.stat_pusher = ZmqPushConnection(
            ZmqFactory(), ZmqEndpoint('connect',
                                      'ipc:///tmp/radiusd-stat-task'))
        self.puller = ZmqPullConnection(
            ZmqFactory(),
            ZmqEndpoint('connect', 'ipc:///tmp/radiusd-auth-message'))
        # Every pulled message is handled by self.process.
        self.puller.onPull = self.process
        # NOTE(review): binds an ephemeral UDP port, but replies are sent via
        # self.pusher, not via this transport — confirm the listener is needed.
        reactor.listenUDP(0, self)
        logger.info("init auth worker pusher : %s " % (self.pusher))
        logger.info("init auth worker puller : %s " % (self.puller))
        logger.info("init auth stat pusher : %s " % (self.stat_pusher))

    def find_nas(self, ip_addr):
        """Return the ``TrBas`` row whose ``ip_addr`` matches.

        The lookup goes through ``self.mcache.aget`` under
        ``bas_cache_key(ip_addr)`` with ``expire=600``. The database is only
        queried by ``fetch_result``, which yields None when no row matches.
        """
        def fetch_result():
            table = models.TrBas.__table__
            with self.db_engine.begin() as conn:
                return conn.execute(
                    table.select().where(table.c.ip_addr == ip_addr)).first()

        return self.mcache.aget(bas_cache_key(ip_addr),
                                fetch_result,
                                expire=600)

    def do_stat(self, code):
        """Push a one-element counter list for RADIUS packet ``code``.

        Unknown codes (including the ``0`` used for errors) count as
        ``auth_drop``. Statistics are best-effort: a failure is logged and
        never raised.
        """
        try:
            if code == packet.AccessRequest:
                stat_msg = ['auth_req']
            elif code == packet.AccessAccept:
                stat_msg = ['auth_accept']
            elif code == packet.AccessReject:
                stat_msg = ['auth_reject']
            else:
                stat_msg = ['auth_drop']
            self.stat_pusher.push(msgpack.packb(stat_msg))
        except Exception as err:
            # Was a bare ``except: pass``, which also hid SystemExit and
            # KeyboardInterrupt.
            logger.error("push auth stat failed: %s" % utils.safeunicode(err))

    def process(self, message):
        """Authorize one pulled message and push the reply, if any.

        :param message: ZeroMQ multipart message whose first frame is a
            msgpack-encoded ``[datagram, host, port]``.
        """
        datagram, host, port = msgpack.unpackb(message[0])
        reply = self.processAuth(datagram, host, port)
        # processAuth returns None for dropped packets. Test identity rather
        # than truthiness, because a packet object may define its own truth
        # value (e.g. if it is dict-like) — TODO confirm.
        if reply is None:
            return
        logger.info("[Radiusd] :: Send radius response: %s" % repr(reply))
        if self.config.system.debug:
            logger.debug(reply.format_str())
        # Replies go back through the ZeroMQ result socket, not self.transport.
        self.pusher.push(msgpack.packb([reply.ReplyPacket(), host, port]))
        self.do_stat(reply.code)

    def createAuthPacket(self, **kwargs):
        """Build an ``AuthMessage`` from ``kwargs`` and run the MAC/VLAN parsers.

        ``vendor_id`` (default 0) is removed from ``kwargs`` and set on the
        message instead of being passed to its constructor.
        """
        vendor_id = kwargs.pop('vendor_id', 0)
        auth_message = message.AuthMessage(**kwargs)
        auth_message.vendor_id = vendor_id
        auth_message = mac_parse.process(auth_message)
        auth_message = vlan_parse.process(auth_message)
        return auth_message

    def processAuth(self, datagram, host, port):
        """Authorize one raw Access-Request and build the reply.

        :param datagram: raw RADIUS packet.
        :param host: source address, used to look up the NAS and its secret.
        :param port: source port (used in error logs).
        :returns: an Access-Accept or Access-Reject reply, or None when the
            packet was dropped (the error is logged and counted as a drop).
        """
        try:
            bas = self.find_nas(host)
            if not bas:
                raise PacketError(
                    '[Radiusd] :: Dropping packet from unknown host %s' % host)

            secret, vendor_id = bas['bas_secret'], bas['vendor_id']
            req = self.createAuthPacket(packet=datagram,
                                        dict=self.dict,
                                        secret=six.b(str(secret)),
                                        vendor_id=vendor_id)

            self.do_stat(req.code)

            logger.info("[Radiusd] :: Received radius request: %s" %
                        (repr(req)))
            if self.config.system.debug:
                logger.debug(req.format_str())

            if req.code != packet.AccessRequest:
                raise PacketError(
                    'non-AccessRequest packet on authentication socket')

            reply = req.CreateReply()
            reply.vendor_id = req.vendor_id

            aaa_request = dict(account_number=req.get_user_name(),
                               domain=req.get_domain(),
                               macaddr=req.client_mac,
                               nasaddr=req.get_nas_addr(),
                               vlanid1=req.vlanid1,
                               vlanid2=req.vlanid2)

            auth_resp = RadiusAuth(self.db_engine, self.mcache, self.aes,
                                   aaa_request).authorize()

            # A positive code means authorize() refused the account; its
            # message is returned before any password check.
            if auth_resp['code'] > 0:
                reply['Reply-Message'] = auth_resp['msg']
                reply.code = packet.AccessReject
                return reply

            # bypass == 0 skips password verification entirely.
            if 'bypass' in auth_resp and int(auth_resp['bypass']) == 0:
                is_pwd_ok = True
            else:
                is_pwd_ok = req.is_valid_pwd(auth_resp.get('passwd'))

            if not is_pwd_ok:
                reply['Reply-Message'] = "password not match"
                reply.code = packet.AccessReject
                return reply

            if u"input_rate" in auth_resp and u"output_rate" in auth_resp:
                reply = rate_process.process(
                    reply,
                    input_rate=auth_resp['input_rate'],
                    output_rate=auth_resp['output_rate'])

            attrs = auth_resp.get("attrs") or {}
            for attr_name in attrs:
                try:
                    # todo: May have a type matching problem
                    reply.AddAttribute(utils.safestr(attr_name),
                                       attrs[attr_name])
                except Exception as err:
                    # Format the exception itself. ``err.message`` is deprecated
                    # in py2 and absent in py3, where it would escape to the
                    # outer handler and drop the whole packet.
                    errstr = "RadiusError:current radius cannot support attribute {0},{1}".format(
                        attr_name, utils.safestr(err))
                    logger.error(errstr)

            for attr, attr_val in req.resp_attrs.iteritems():
                reply[attr] = attr_val

            reply['Reply-Message'] = 'success!'
            reply.code = packet.AccessAccept
            if not req.VerifyReply(reply):
                raise PacketError('VerifyReply error')
            return reply
        except Exception as err:
            self.do_stat(0)
            errstr = 'RadiusError:Dropping invalid auth packet from {0} {1},{2}'.format(
                host, port, utils.safeunicode(err))
            logger.error(errstr)
            import traceback
            traceback.print_exc()
Пример #10
0
class RADIUSAcctWorker(TraceMix):
    """Accounting subprocess: runs the billing logic and pushes results to the
    main RADIUS protocol process.

    Accounting is asynchronous: a reply is built for each valid accounting
    message and returned for immediate pushing, while the billing logic is
    scheduled to run afterwards on the reactor.
    """

    def __init__(self, config, dbengine, radcache = None):
        """
        :param config: application config; ``config.mqproxy`` supplies the
            ``task_connect``, ``acct_result`` and ``acct_message`` endpoints.
        :param dbengine: database engine; when falsy one is created with
            ``get_engine(config)``.
        :param radcache: cache object handed to the billing handlers.
        """
        self.config = config
        self.load_plugins(load_types=['radius_acct_req'])
        self.db_engine = dbengine or get_engine(config)
        self.mcache = radcache
        self.dict = dictionary.Dictionary(os.path.join(os.path.dirname(taurusxradius.__file__), 'dictionarys/dictionary'))
        self.stat_pusher = ZmqPushConnection(ZmqFactory(), ZmqEndpoint('connect', config.mqproxy['task_connect']))
        self.pusher = ZmqPushConnection(ZmqFactory(), ZmqEndpoint('connect', config.mqproxy['acct_result']))
        self.puller = ZmqPullConnection(ZmqFactory(), ZmqEndpoint('connect', config.mqproxy['acct_message']))
        # Every pulled message is handled by self.process.
        self.puller.onPull = self.process
        # Acct-Status-Type -> handler class; instances expose ``acctounting``.
        self.acct_class = {STATUS_TYPE_START: RadiusAcctStart,
         STATUS_TYPE_STOP: RadiusAcctStop,
         STATUS_TYPE_UPDATE: RadiusAcctUpdate,
         STATUS_TYPE_ACCT_ON: RadiusAcctOnoff,
         STATUS_TYPE_ACCT_OFF: RadiusAcctOnoff}
        logger.info('radius acct worker %s start' % os.getpid())
        logger.info('init acct worker pusher : %s ' % self.pusher)
        logger.info('init acct worker puller : %s ' % self.puller)
        logger.info('init acct stat pusher : %s ' % self.stat_pusher)

    def do_stat(self, code, status_type = 0, req = None):
        """Push ``{'statattrs': [...], 'raddata': {...}}`` counters for one packet.

        :param code: RADIUS packet code. Codes other than 4 and 5 count as
            ``acct_drop``.
        :param status_type: Acct-Status-Type; 1/2/3/7/8 add
            start/stop/update/on/off counters.
        :param req: the request. For updates (3) its input/output totals are
            added to ``raddata`` when it is given.

        Statistics are best-effort: a failure is logged and never raised.
        """
        try:
            stat_msg = {'statattrs': ['acct_drop'],
             'raddata': {}}
            if code in (4, 5):
                stat_msg['statattrs'] = []
                if code == packet.AccountingRequest:
                    stat_msg['statattrs'].append('acct_req')
                elif code == packet.AccountingResponse:
                    stat_msg['statattrs'].append('acct_resp')
                if status_type == 1:
                    stat_msg['statattrs'].append('acct_start')
                elif status_type == 2:
                    stat_msg['statattrs'].append('acct_stop')
                elif status_type == 3:
                    stat_msg['statattrs'].append('acct_update')
                    # Without a request the counters are still pushed. Before,
                    # req=None raised here and the message was silently lost.
                    if req is not None:
                        stat_msg['raddata']['input_total'] = req.get_input_total()
                        stat_msg['raddata']['output_total'] = req.get_output_total()
                elif status_type == 7:
                    stat_msg['statattrs'].append('acct_on')
                elif status_type == 8:
                    stat_msg['statattrs'].append('acct_off')
            self.stat_pusher.push(msgpack.packb(stat_msg))
        except Exception as err:
            # Was a bare ``except: pass``, which also hid SystemExit and
            # KeyboardInterrupt.
            logger.error(u'push acct stat message failed: %r' % (err,))

    def process(self, message):
        """Handle one pulled message and push the reply, if one was produced.

        :param message: ZeroMQ multipart message whose first frame is a
            msgpack-encoded ``[datagram, host, port]``.
        """
        datagram, host, port = msgpack.unpackb(message[0])
        reply = self.processAcct(datagram, host, port)
        if reply is not None:
            self.pusher.push(msgpack.packb([reply.ReplyPacket(), host, port]))

    def createAcctPacket(self, **kwargs):
        """Build an ``AcctMessage`` and pass it through every ``radius_acct_req`` plugin.

        ``vendor_id`` (default 0) is removed from ``kwargs`` and set on the
        message instead of being passed to its constructor.
        """
        vendor_id = kwargs.pop('vendor_id', 0)
        acct_message = message.AcctMessage(**kwargs)
        acct_message.vendor_id = vendor_id
        for plugin in self.acct_req_plugins:
            acct_message = plugin.plugin_func(acct_message)

        return acct_message

    def processAcct(self, datagram, host, port):
        """Validate one accounting packet, schedule billing and return the reply.

        The packet is first parsed with an empty secret so that the NAS can be
        looked up by address or by NAS identifier. The real secret and vendor
        are then set on the request.

        :returns: the Accounting-Response, or None when the packet is dropped
            (the reason is logged).
        """
        try:
            req = self.createAcctPacket(packet=datagram, dict=self.dict, secret=six.b(''), vendor_id=0)
            bas = self.find_nas(host) or self.find_nas_byid(req.get_nas_id())
            if not bas:
                raise PacketError(u'Unauthorized access Nas %s' % host)
            secret, vendor_id = bas['bas_secret'], bas['vendor_id']
            req.secret = six.b(str(secret))
            req.vendor_id = vendor_id
            self.log_trace(host, port, req)
            status_type = req.get_acct_status_type()
            self.do_stat(req.code, status_type, req=req)
            if req.code != packet.AccountingRequest:
                errstr = u'Invalid accounting request code=%s' % req.code
                logger.error(errstr, tag='radius_acct_drop', trace='radius', username=req.get_user_name())
                return None
            if not req.VerifyAcctRequest():
                errstr = u'Check accounting request failed, please check shared secret'
                logger.error(errstr, tag='radius_acct_drop', trace='radius', username=req.get_user_name())
                return None
            if status_type not in self.acct_class:
                errstr = u'accounting type <%s> not supported' % status_type
                logger.error(errstr, tag='radius_acct_drop', trace='radius', username=req.get_user_name())
                return None
            ticket = req.get_ticket()
            ticket['nas_addr'] = host
            acct_func = self.acct_class[status_type](self.db_engine, self.mcache, None, ticket).acctounting
            # Billing, tracing and reply stats are deferred so that the reply
            # can be pushed first.
            reactor.callLater(0.05, acct_func)
            reply = req.CreateReply()
            reactor.callLater(0.05, self.log_trace, host, port, req, reply)
            reactor.callLater(0.05, self.do_stat, reply.code)
            return reply
        except Exception as err:
            self.do_stat(0)
            logger.exception(err, tag='radius_acct_drop')

        return None
Пример #11
0
class RADIUSAuthWorker(TraceMix):
    """ Authentication worker process: handles the authentication/authorization
    logic and pushes the results to the main RADIUS protocol-handling process.
    """

    def __init__(self, config, dbengine, radcache = None):
        """Create an auth worker bound to ``config``.

        Setup performed here:

        - loads the ``radius_auth_req`` / ``radius_accept`` plugins through
          ``load_plugins`` (defined outside this class body);
        - loads the RADIUS dictionary shipped with ``taurusxradius``;
        - sets the DB engine to ``dbengine`` or, when falsy, ``get_engine(config)``;
        - creates an AES cipher keyed by ``config.system.secret``;
        - stores the cache ``radcache``.

        Three ZeroMQ connections are opened from the ``config.mqproxy``
        endpoints:

        - a push connection for auth results (``auth_result``);
        - a push connection for statistics (``task_connect``);
        - a pull connection for incoming auth messages (``auth_message``),
          whose messages are handled by :meth:`process`.
        """
        self.config = config
        self.load_plugins(load_types=['radius_auth_req', 'radius_accept'])
        dictionary_path = os.path.join(os.path.dirname(taurusxradius.__file__), 'dictionarys/dictionary')
        self.dict = dictionary.Dictionary(dictionary_path)
        self.db_engine = dbengine or get_engine(config)
        self.aes = utils.AESCipher(key=self.config.system.secret)
        self.mcache = radcache
        reject_debug_flag = self.get_param_value('radius_reject_debug', 0)
        self.reject_debug = int(reject_debug_flag) == 1

        endpoints = config.mqproxy

        def zmq_connect(connection_cls, endpoint_name):
            # Every connection is created with its own ZmqFactory instance.
            return connection_cls(ZmqFactory(), ZmqEndpoint('connect', endpoints[endpoint_name]))

        self.pusher = zmq_connect(ZmqPushConnection, 'auth_result')
        self.stat_pusher = zmq_connect(ZmqPushConnection, 'task_connect')
        self.puller = zmq_connect(ZmqPullConnection, 'auth_message')
        self.puller.onPull = self.process
        logger.info('radius auth worker %s start' % os.getpid())
        logger.info('init auth worker pusher : %s ' % self.pusher)
        logger.info('init auth worker puller : %s ' % self.puller)
        logger.info('init auth stat pusher : %s ' % self.stat_pusher)
        # Upper bound on TrOnline rows; process() drops requests at or above it.
        self.license_ulimit = 50000

    def get_account_bind_nas_ipaddrs(self, account_number, ip_addr):
        """Return the BAS IP-address rows bound to the user's area that equal ``ip_addr``.

        A BAS is bound to an account through account -> customer -> node -> BAS.
        An empty result means the NAS at ``ip_addr`` is not bound to the user's
        area.

        The cache entry is keyed by ``account_number`` only.  It therefore holds
        *all* bound BAS addresses of the account, and the ``ip_addr`` filter is
        applied after the cache lookup.  Filtering inside the cached query
        would let a result computed for one NAS be served for a different NAS.

        :param account_number: the user's account number.
        :param ip_addr: the requesting NAS's IP address.
        :return: list of single-column (``ip_addr``) rows whose value equals
            ``ip_addr``.
        """

        def fetch_result():
            tbas = models.TrBas.__table__
            tcus = models.TrCustomer.__table__
            tuser = models.TrAccount.__table__
            tbn = models.TrBasNode.__table__
            stmt = tbas.select().with_only_columns([tbas.c.ip_addr]).where(
                tcus.c.customer_id == tuser.c.customer_id).where(
                tcus.c.node_id == tbn.c.node_id).where(
                tbn.c.bas_id == tbas.c.id).where(
                tuser.c.account_number == account_number)
            # A single connection/transaction is enough for this read.
            with self.db_engine.begin() as conn:
                return list(conn.execute(stmt))

        bound_rows = self.mcache.aget(account_bind_basip_key(account_number), fetch_result, expire=3600) or []
        wanted = utils.safeunicode(ip_addr)
        # NOTE(review): Python-side comparison is exact text equality; a DB
        # collation could have been looser.
        return [row for row in bound_rows if utils.safeunicode(row[0]) == wanted]

    def get_account_bind_nas_ids(self, account_number, nas_id):
        """Return the BAS identifier rows bound to the user's area that equal ``nas_id``.

        A BAS is bound to an account through account -> customer -> node -> BAS.
        An empty result means the NAS identified by ``nas_id`` is not bound to
        the user's area.

        The cache entry is keyed by ``account_number`` only.  It therefore holds
        *all* bound BAS identifiers of the account, and the ``nas_id`` filter is
        applied after the cache lookup.  Filtering inside the cached query
        would let a result computed for one NAS be served for a different NAS.

        :param account_number: the user's account number.
        :param nas_id: the requesting NAS's identifier.
        :return: list of single-column (``nas_id``) rows whose value equals
            ``nas_id``.
        """

        def fetch_result():
            tbas = models.TrBas.__table__
            tcus = models.TrCustomer.__table__
            tuser = models.TrAccount.__table__
            tbn = models.TrBasNode.__table__
            stmt = tbas.select().with_only_columns([tbas.c.nas_id]).where(
                tcus.c.customer_id == tuser.c.customer_id).where(
                tcus.c.node_id == tbn.c.node_id).where(
                tbn.c.bas_id == tbas.c.id).where(
                tuser.c.account_number == account_number)
            # A single connection/transaction is enough for this read.
            with self.db_engine.begin() as conn:
                return list(conn.execute(stmt))

        bound_rows = self.mcache.aget(account_bind_basid_key(account_number), fetch_result, expire=3600) or []
        wanted = utils.safeunicode(nas_id)
        # NOTE(review): Python-side comparison is exact text equality; a DB
        # collation could have been looser (e.g. case-insensitive).
        return [row for row in bound_rows if utils.safeunicode(row[0]) == wanted]

    def do_stat(self, code):
        """Push one authentication statistics event for RADIUS packet ``code``.

        Codes map to statistics attributes as follows:

        - Access-Request -> ``auth_req``
        - Access-Accept -> ``auth_accept``
        - Access-Reject -> ``auth_reject``
        - any other code -> ``auth_drop`` (callers pass 0 on errors)

        The msgpack-encoded ``{'statattrs': [...], 'raddata': {}}`` message is
        pushed on ``self.stat_pusher``.

        Statistics are best effort: a failure here is logged rather than raised,
        so it never interrupts request handling.
        """
        try:
            if code == packet.AccessRequest:
                stat_attr = 'auth_req'
            elif code == packet.AccessAccept:
                stat_attr = 'auth_accept'
            elif code == packet.AccessReject:
                stat_attr = 'auth_reject'
            else:
                stat_attr = 'auth_drop'
            stat_msg = {'statattrs': [stat_attr],
             'raddata': {}}
            self.stat_pusher.push(msgpack.packb(stat_msg))
        except Exception as err:
            # Catch Exception (not a bare except) so KeyboardInterrupt and
            # SystemExit still propagate.  %r keeps the message ASCII-safe on Python 2.
            logger.error(u'radius auth stat push error: %r' % err)

    def process(self, message):
        """Handle one ZeroMQ pull message carrying a msgpack-encoded auth request.

        ``message[0]`` is unpacked as ``(datagram, host, port)``.  Nothing is
        processed while the TrOnline table holds ``self.license_ulimit`` rows or
        more; in that case an error is logged instead.

        Otherwise the reply from :meth:`processAuth` is pushed back as
        ``[reply_packet, host, port]`` and counted with :meth:`do_stat`.  When
        ``processAuth`` returns ``None``, nothing is sent.  A reject's
        Reply-Message is removed unless the ``radius_reject_message`` parameter
        is non-zero.
        """
        online_table = models.TrOnline.__table__
        with self.db_engine.begin() as conn:
            online_count = conn.execute(online_table.count()).scalar()
        if online_count >= self.license_ulimit:
            logger.error(u'Online user empowerment has been limited <%s>' % self.license_ulimit)
            return
        datagram, host, port = msgpack.unpackb(message[0])
        reply = self.processAuth(datagram, host, port)
        if reply is None:
            return
        hide_reject_message = (reply.code == packet.AccessReject
                               and 'Reply-Message' in reply
                               and int(self.get_param_value('radius_reject_message', 0)) == 0)
        if hide_reject_message:
            del reply['Reply-Message']
        self.pusher.push(msgpack.packb([reply.ReplyPacket(), host, port]))
        self.do_stat(reply.code)

    def createAuthPacket(self, **kwargs):
        """Build an auth request object and pass it through the request plugins.

        ``vendor_id`` (default 0) is removed from ``kwargs`` and set on the
        result.  All other keyword arguments go to ``message.AuthMessage``; for
        example, processAuth passes ``packet``, ``dict`` and ``secret``.  Each
        plugin in ``self.auth_req_plugins`` receives the current object and
        returns the object to use next.
        """
        vendor = kwargs.pop('vendor_id', 0)
        auth_req = message.AuthMessage(**kwargs)
        auth_req.vendor_id = vendor
        for req_plugin in self.auth_req_plugins:
            auth_req = req_plugin.plugin_func(auth_req)
        return auth_req

    def freeReply(self, req):
        """Build the Access-Accept used for authentication-free access.

        The reply carries the default policy read from these parameters:

        - ``radius_free_input_rate`` / ``radius_free_output_rate``: rates;
        - ``radius_free_rate_code`` / ``radius_free_domain``: rate code and domain;
        - ``radius_max_session_timeout``: Session-Timeout.

        Each plugin in ``self.auth_accept_plugins`` may then transform the
        reply.
        """
        reply = req.CreateReply()
        reply.vendor_id = req.vendor_id
        reply['Reply-Message'] = u'User:%s FreeAuth Success' % req.get_user_name()
        reply.code = packet.AccessAccept

        param = self.get_param_value
        input_rate = int(param('radius_free_input_rate', 1048576))
        output_rate = int(param('radius_free_output_rate', 4194304))
        rate_code = param('radius_free_rate_code', 'freerate')
        domain = param('radius_free_domain', 'freedomain')
        session_timeout = int(param('radius_max_session_timeout', 86400))
        policy = {
            'attrs': {'Session-Timeout': session_timeout},
            'input_rate': input_rate,
            'output_rate': output_rate,
            'rate_code': rate_code,
            'domain': domain,
        }
        for accept_plugin in self.auth_accept_plugins:
            reply = accept_plugin.plugin_func(reply, policy)
        return reply

    def rejectReply(self, req, errmsg=''):
        """Return an Access-Reject reply to ``req`` with ``errmsg`` as its Reply-Message.

        The reply inherits the request's ``vendor_id``.
        """
        rejection = req.CreateReply()
        rejection.vendor_id = req.vendor_id
        rejection['Reply-Message'] = errmsg
        rejection.code = packet.AccessReject
        return rejection

    def processAuth(self, datagram, host, port):
        """Authenticate one raw Access-Request datagram received from ``host:port``.

        Steps, in order:

        1. Decode the datagram with an empty secret and identify the NAS.  The
           NAS is looked up first by source address (``find_nas``), then by the
           request's NAS id (``find_nas_byid``).  An unknown NAS raises
           ``PacketError``.
        2. Apply the NAS's ``bas_secret`` and ``vendor_id`` to the request.
        3. Drop (log, return ``None``) anything that is not an Access-Request.
        4. If the ``radius_bypass`` parameter is 2, accept via :meth:`freeReply`
           without the user/area/authorize checks below.
        5. Reject unknown users.  Also reject users whose area is not bound to
           this NAS, checked by IP or by NAS id to match how the NAS was found.
        6. Run ``RadiusAuth(...).authorize()``.  A positive ``code`` in its
           result becomes an Access-Reject carrying its ``msg``.
        7. Otherwise build an Access-Accept and let ``self.auth_accept_plugins``
           transform it.  Drop the reply if ``req.VerifyReply`` fails.

        :return: the Access-Accept / Access-Reject reply, or ``None`` on the
            explicit drop paths.  Any exception is caught, counted with
            ``do_stat(0)`` and logged.
        """
        try:
            # The NAS (and therefore its shared secret) is only known after the
            # packet is parsed, so decode with an empty secret first.
            req = self.createAuthPacket(packet=datagram, dict=self.dict, secret=six.b(''), vendor_id=0)
            nas_id = req.get_nas_id()
            # Remember how the NAS was matched; the area-binding check below
            # uses the same key (IP address or NAS id).
            bastype = 'ipaddr'
            bas = self.find_nas(host)
            if not bas:
                bastype = 'nasid'
                bas = self.find_nas_byid(nas_id)
                if not bas:
                    raise PacketError(u'Unauthorized Access Nas %s' % host)
            secret, vendor_id = bas['bas_secret'], bas['vendor_id']
            req.secret = six.b(str(secret))
            req.vendor_id = vendor_id
            username = req.get_user_name()
            bypass = int(self.get_param_value('radius_bypass', 1))
            if req.code != packet.AccessRequest:
                errstr = u'Illegal Auth request, code=%s' % req.code
                logger.error(errstr, tag='radius_auth_drop', trace='radius', username=username)
                return
            self.log_trace(host, port, req)
            self.do_stat(req.code)
            # radius_bypass == 2: accept with the default free policy before any
            # user or area checks.  Other values are passed on to RadiusAuth.
            if bypass == 2:
                reply = self.freeReply(req)
                self.log_trace(host, port, req, reply)
                return reply
            if not self.user_exists(username):
                errmsg = u'Auth Error:user:%s not exists' % utils.safeunicode(username)
                reply = self.rejectReply(req, errmsg)
                self.log_trace(host, port, req, reply)
                return reply
            # An empty binding list means this NAS is not linked to the user's area.
            if bastype == 'ipaddr':
                bind_nasip_list = self.get_account_bind_nas_ipaddrs(username, host)
                if not bind_nasip_list:
                    errmsg = u'Nas:%s not bind user:%s area' % (host, username)
                    reply = self.rejectReply(req, errmsg)
                    self.log_trace(host, port, req, reply)
                    return reply
            elif bastype == 'nasid':
                bind_nasid_list = self.get_account_bind_nas_ids(username, nas_id)
                if not bind_nasid_list:
                    errmsg = u'Nas:%s not bind user:%s area' % (nas_id, username)
                    reply = self.rejectReply(req, errmsg)
                    self.log_trace(host, port, req, reply)
                    return reply
            aaa_request = dict(account_number=username, domain=req.get_domain(), macaddr=req.client_mac, nasaddr=req.get_nas_addr(), vlanid1=req.vlanid1, vlanid2=req.vlanid2, bypass=bypass, radreq=req)
            auth_resp = RadiusAuth(self.db_engine, self.mcache, self.aes, aaa_request).authorize()
            # A positive code from authorize() means the user is rejected.
            if auth_resp['code'] > 0:
                reply = self.rejectReply(req, auth_resp['msg'])
                self.log_trace(host, port, req, reply)
                return reply
            reply = req.CreateReply()
            reply.code = packet.AccessAccept
            reply.vendor_id = req.vendor_id
            # Each of the next two lines parses as `(fmt % value) if cond else ''`.
            extmsg = u'domain=%s;' % auth_resp['domain'] if 'domain' in auth_resp else ''
            extmsg += u'rate_policy=%s;' % auth_resp['rate_code'] if 'rate_code' in auth_resp else ''
            reply['Reply-Message'] = u'User:%s Auth success; %s' % (username, extmsg)
            for plugin in self.auth_accept_plugins:
                reply = plugin.plugin_func(reply, auth_resp)

            # NOTE(review): presumably checks the reply against the request and
            # shared secret; on failure the reply is dropped, not sent.
            if not req.VerifyReply(reply):
                errstr = u'User:%s Auth message error, Please check share secret' % username
                logger.error(errstr, tag='radius_auth_drop', trace='radius', username=username)
                return
            self.log_trace(host, port, req, reply)
            return reply
        except Exception as err:
            # Count the failure as a drop (do_stat maps code 0 to auth_drop) and log it.
            self.do_stat(0)
            logger.exception(err, tag='radius_auth_error')