class BenchmarkWorker:
    """Schedule a burst of RADIUS Access-Requests on the Twisted reactor.

    All work happens during construction:

    * Twisted logging is redirected to ``/tmp/trbctl-worker-<LOGID>.log``.
    * A ZeroMQ PUSH connection to ``ipc:///tmp/toughbt-message`` is opened
      for status messages.
    * ``concurrency`` ``RadAuthClient`` instances are created.
    * ``requests`` ``sendAuth`` calls are queued, spaced ``1.0 / rate``
      seconds apart.
    * ``on_timeout`` is queued to fire after ``timeout`` seconds.

    :param server: RADIUS server address handed to every client.
    :param port: RADIUS server port.
    :param secret: shared secret (converted with ``str()``).
    :param requests: number of Access-Requests to schedule.
    :param concurrency: number of clients; requests are dealt round-robin.
    :param username: value sent as ``User-Name``.
    :param password: value sent as ``User-Password``.
    :param verb: forwarded to each client as ``debug``.
    :param timeout: seconds until ``on_timeout`` stops the reactor.
    :param rate: requests per second; used as a divisor, so it must be non-zero.
    """

    def __init__(self, server, port, secret, requests, concurrency, username, password,
                 verb=False, timeout=600, rate=1000):
        log_path = "/tmp/trbctl-worker-{}.log".format(os.environ.get("LOGID", 0))
        log.startLogging(open(log_path, 'w'))
        self.timeout = timeout

        self.pusher = ZmqPushConnection(
            ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/toughbt-message'))
        self.pusher.push("write worker %s log into %s" % (os.getpid(), log_path))
        log.msg("init BenchmarkWorker pusher : %s " % repr(self.pusher))

        # Build the client pool once; each scheduled send takes the next client.
        raddict = dictionary.Dictionary(
            os.path.join(os.path.dirname(__file__), "dictionary"))
        clients = []
        for _ in range(concurrency):
            clients.append(RadAuthClient(str(secret), raddict, server, port=port,
                                         debug=verb, stat_push=self.pusher))
        client_ring = itertools.cycle(clients)
        credentials = {'User-Name': username, 'User-Password': password}

        def send():
            return next(client_ring).sendAuth(**credentials)

        # Space the sends evenly: the k-th send fires k * interval seconds from now.
        interval = 1.0 / rate
        delay = interval
        for _ in xrange(requests):
            reactor.callLater(delay, send)
            delay += interval
        reactor.callLater(self.timeout, self.on_timeout)

    def on_timeout(self):
        """Report that the benchmark window elapsed and stop the reactor."""
        notice = "logger: BenchmarkWorker timeout, running times: %s" % self.timeout
        self.pusher.push(notice)
        reactor.stop()
class ZeroMQDelegatorService(Service):
    """
    This is an outbound PUSH connection to the web API ZeroMQRepeaterService
    that allows a worker to delegate any sub-links it finds (instead of taking
    a detour to crawl them on their own).
    """

    def __init__(self):
        # The PUSH connection; None until startService() runs and after stopService().
        self.conn = None

    def startService(self):
        """Connect a PUSH socket to ``ZMQ_REPEATER`` and mark the service running."""
        # Twisted's Service.startService sets ``self.running``; subclasses must
        # chain up to it or the service never reports itself as running.
        Service.startService(self)
        factory = ZmqFactory()
        log.msg("Delegator connecting to repeater: %s" % ZMQ_REPEATER)
        endpoint = ZmqEndpoint('connect', ZMQ_REPEATER)
        self.conn = ZmqPushConnection(factory, endpoint)

    def stopService(self):
        """Close the PUSH socket (if one was opened) and mark the service stopped."""
        if self.conn is not None:
            self.conn.shutdown()
            self.conn = None
        return Service.stopService(self)

    def send_message(self, message):
        """
        Matches the signature of ZeroMQBroadcastService so we can use them
        interchangeably in the job queue code.

        :param str message: A JSON crawler message.
        """
        log.msg("Delegating job: %s" % message)
        self.conn.push(message)
class TwistedZmqClient(object):
    """Minimal wrapper around a txZMQ PUSH socket connected to ``service``."""

    def __init__(self, service):
        """
        :param service: any object exposing ``host`` and ``port`` attributes;
            the socket connects to ``tcp://<host>:<port>``.
        """
        factory = ZmqFactory()
        address = 'tcp://%s:%s' % (service.host, service.port)
        self.conn = ZmqPushConnection(factory, ZmqEndpoint('connect', address))

    def send(self, msg):
        """Push ``msg`` unchanged onto the socket."""
        self.conn.push(msg)
class TwistedZmqClient(object):
    """txZMQ PUSH client that sends ``{clientid, levellist}`` records as JSON."""

    def __init__(self, service):
        """
        :param service: any object exposing ``host`` and ``port`` attributes;
            the socket connects to ``tcp://<host>:<port>``.
        """
        factory = ZmqFactory()
        address = 'tcp://%s:%s' % (service.host, service.port)
        self.conn = ZmqPushConnection(factory, ZmqEndpoint('connect', address))

    def send(self, clientid, levellist):
        """Serialize the pair as a JSON object and push it onto the socket."""
        payload = {'clientid': clientid, 'levellist': levellist}
        encoded = json.dumps(payload)
        self.conn.push(encoded)
class ZeroMQBroadcastService(Service):
    """
    This is used by the HTTP API to hand crawl jobs off to the crawler pool.
    """

    def __init__(self, bind_point='tcp://0.0.0.0:8050'):
        """
        :param str bind_point: ZeroMQ endpoint the PUSH socket binds to; the
            default keeps the previously hard-coded address.
        """
        # The PUSH connection; None until startService() runs and after stopService().
        self.conn = None
        self.bind_point = bind_point

    def startService(self):
        """Bind the PUSH socket on ``self.bind_point`` and mark the service running."""
        # Chain up so Twisted's ``running`` flag is set.
        Service.startService(self)
        factory = ZmqFactory()
        log.msg("Broadcaster binding on: %s" % self.bind_point)
        endpoint = ZmqEndpoint('bind', self.bind_point)
        self.conn = ZmqPushConnection(factory, endpoint)

    def stopService(self):
        """Close the PUSH socket (releasing the bound port) and mark the service stopped."""
        if self.conn is not None:
            self.conn.shutdown()
            self.conn = None
        return Service.stopService(self)

    def send_message(self, message):
        """Push one crawl announcement to the crawler pool.

        :param str message: A JSON crawler message.
        """
        log.msg("Sent crawl announcement: %s" % message)
        self.conn.push(message)
class RADIUSAuthWorker(object):
    """RADIUS authentication worker fed over ZeroMQ IPC sockets.

    Data flow:

    * Pulls msgpack-encoded ``[datagram, host, port]`` messages from
      ``ipc:///tmp/radiusd-auth-message``; ``self.process`` is the pull
      callback.
    * Authenticates each request.
    * Pushes ``[reply_bytes, host, port]`` to ``ipc:///tmp/radiusd-auth-result``.
    * Pushes per-packet statistics to ``ipc:///tmp/radiusd-stat-task``.
    """

    def __init__(self, config, dbengine):
        """
        :param config: application config; ``config.system.secret`` and
            ``config.system.debug`` are read.
        :param dbengine: database engine used via ``begin()``, or a falsy value
            to build one with ``get_engine(config)``.
        """
        self.config = config
        self.dict = dictionary.Dictionary(
            os.path.join(os.path.dirname(toughradius.__file__), 'dictionarys/dictionary'))
        self.db_engine = dbengine or get_engine(config)
        # Cipher keyed with the system secret; handed to RadiusAuth below.
        self.aes = utils.AESCipher(key=self.config.system.secret)
        self.mcache = mcache.Mcache()
        self.pusher = ZmqPushConnection(ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-auth-result'))
        self.stat_pusher = ZmqPushConnection(ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-stat-task'))
        self.puller = ZmqPullConnection(ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-auth-message'))
        self.puller.onPull = self.process
        logger.info("init auth worker pusher : %s " % (self.pusher))
        logger.info("init auth worker puller : %s " % (self.puller))
        logger.info("init auth stat pusher : %s " % (self.stat_pusher))

    def find_nas(self,ip_addr):
        """Return the TrBas row whose ``ip_addr`` matches, cached for 600 seconds.

        :param ip_addr: source address of the incoming datagram.
        :returns: the first matching row, or whatever falsy value the
            cache/fetch yields when no NAS is registered.
        """
        def fetch_result():
            # Cache-miss path: single-row lookup in its own transaction.
            table = models.TrBas.__table__
            with self.db_engine.begin() as conn:
                return conn.execute(table.select().where(table.c.ip_addr==ip_addr)).first()
        return self.mcache.aget(bas_cache_key(ip_addr),fetch_result, expire=600)

    def do_stat(self,code):
        """Push one statistics counter name for a RADIUS packet code.

        Request/accept/reject codes map to ``auth_req``/``auth_accept``/
        ``auth_reject``; anything else (``0`` is used for failures) counts as
        ``auth_drop``.
        """
        try:
            stat_msg = []
            if code == packet.AccessRequest:
                stat_msg.append('auth_req')
            elif code == packet.AccessAccept:
                stat_msg.append('auth_accept')
            elif code == packet.AccessReject:
                stat_msg.append('auth_reject')
            else:
                stat_msg = ['auth_drop']
            self.stat_pusher.push(msgpack.packb(stat_msg))
        except:
            # Statistics are best-effort; any failure (including the stat
            # socket) is deliberately ignored.
            pass

    def process(self, message):
        """ZeroMQ pull callback: authenticate one request and push the reply.

        :param message: multipart message whose first frame is a msgpack
            ``[datagram, host, port]`` triple.
        """
        datagram, host, port = msgpack.unpackb(message[0])
        reply = self.processAuth(datagram, host, port)
        # processAuth returns None for dropped packets; nothing is sent back.
        if not reply:
            return
        self.do_stat(reply.code)
        logger.info("[Radiusd] :: Send radius response: %s" % repr(reply))
        if self.config.system.debug:
            logger.debug(reply.format_str())
        self.pusher.push(msgpack.packb([reply.ReplyPacket(),host,port]))

    def createAuthPacket(self, **kwargs):
        """Build an AuthMessage and run it through the MAC and VLAN parsers.

        ``vendor_id`` (default 0) is removed from ``kwargs`` and set as an
        attribute; the remaining kwargs go to ``message.AuthMessage``.
        """
        vendor_id = kwargs.pop('vendor_id',0)
        auth_message = message.AuthMessage(**kwargs)
        auth_message.vendor_id = vendor_id
        auth_message = mac_parse.process(auth_message)
        auth_message = vlan_parse.process(auth_message)
        return auth_message

    def processAuth(self, datagram, host, port):
        """Authenticate one raw datagram.

        :returns: an Access-Accept or Access-Reject reply. Returns None when
            the packet is dropped; any exception is logged and counted as
            ``auth_drop``.
        """
        try:
            bas = self.find_nas(host)
            if not bas:
                raise PacketError('[Radiusd] :: Dropping packet from unknown host %s' % host)
            secret, vendor_id = bas['bas_secret'], bas['vendor_id']
            req = self.createAuthPacket(packet=datagram, dict=self.dict, secret=six.b(str(secret)),vendor_id=vendor_id)
            self.do_stat(req.code)
            logger.info("[Radiusd] :: Received radius request: %s" % (repr(req)))
            if self.config.system.debug:
                logger.debug(req.format_str())
            if req.code != packet.AccessRequest:
                raise PacketError('non-AccessRequest packet on authentication socket')
            reply = req.CreateReply()
            reply.vendor_id = req.vendor_id
            # Fields handed to the RadiusAuth policy check.
            aaa_request = dict(
                account_number=req.get_user_name(),
                domain=req.get_domain(),
                macaddr=req.client_mac,
                nasaddr=req.get_nas_addr(),
                vlanid1=req.vlanid1,
                vlanid2=req.vlanid2
            )
            auth_resp = RadiusAuth(self.db_engine,self.mcache,self.aes,aaa_request).authorize()
            # A positive code means authorization was refused; msg explains why.
            if auth_resp['code'] > 0:
                reply['Reply-Message'] = auth_resp['msg']
                reply.code = packet.AccessReject
                return reply
            # bypass == 0 in the response skips the password check.
            if 'bypass' in auth_resp and int(auth_resp['bypass']) == 0:
                is_pwd_ok = True
            else:
                is_pwd_ok = req.is_valid_pwd(auth_resp.get('passwd'))
            if not is_pwd_ok:
                reply['Reply-Message'] = "password not match"
                reply.code = packet.AccessReject
                return reply
            else:
                if u"input_rate" in auth_resp and u"output_rate" in auth_resp:
                    reply = rate_process.process(
                        reply, input_rate=auth_resp['input_rate'], output_rate=auth_resp['output_rate'])
                attrs = auth_resp.get("attrs") or {}
                for attr_name in attrs:
                    try:
                        # todo: May have a type matching problem
                        reply.AddAttribute(utils.safestr(attr_name), attrs[attr_name])
                    except Exception as err:
                        # An unsupported attribute is logged and skipped; the
                        # accept still goes out.
                        errstr = "RadiusError:current radius cannot support attribute {0},{1}".format(
                            attr_name,utils.safestr(err.message))
                        logger.error(errstr)
                # Attributes collected on the request (resp_attrs) are copied
                # onto the reply.
                for attr, attr_val in req.resp_attrs.iteritems():
                    reply[attr] = attr_val
                reply['Reply-Message'] = 'success!'
                reply.code = packet.AccessAccept
                if not req.VerifyReply(reply):
                    raise PacketError('VerifyReply error')
                return reply
        except Exception as err:
            self.do_stat(0)
            errstr = 'RadiusError:Dropping invalid auth packet from {0} {1},{2}'.format(
                host, port, utils.safeunicode(err))
            logger.error(errstr)
            import traceback
            traceback.print_exc()
class RADIUSAcctWorker(object):
    """RADIUS accounting worker fed over ZeroMQ IPC sockets.

    Data flow:

    * Pulls msgpack-encoded ``[datagram, host, port]`` messages from
      ``ipc:///tmp/radiusd-acct-message``.
    * Acknowledges each valid Accounting-Request immediately by pushing
      ``[reply_bytes, host, port]`` to ``ipc:///tmp/radiusd-acct-result``.
    * Schedules the handler registered for the packet's Acct-Status-Type on
      the reactor.
    * Pushes statistics to ``ipc:///tmp/radiusd-stat-task``.
    """

    # Acct-Status-Type value (RFC 2866) -> statistics counter name.
    _STATUS_STAT_NAMES = {
        1: 'acct_start',
        2: 'acct_stop',
        3: 'acct_update',
        7: 'acct_on',
        8: 'acct_off',
    }

    def __init__(self, config, dbengine):
        """
        :param config: application config; ``config.system.debug`` is read.
        :param dbengine: database engine used via ``begin()``, or a falsy value
            to build one with ``get_engine(config)``.
        """
        self.config = config
        self.dict = dictionary.Dictionary(
            os.path.join(os.path.dirname(toughradius.__file__), 'dictionarys/dictionary'))
        self.db_engine = dbengine or get_engine(config)
        self.mcache = mcache.Mcache()
        self.pusher = ZmqPushConnection(
            ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-acct-result'))
        self.stat_pusher = ZmqPushConnection(
            ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-stat-task'))
        self.puller = ZmqPullConnection(
            ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-acct-message'))
        self.puller.onPull = self.process
        logger.info("init acct worker pusher : %s " % (self.pusher))
        logger.info("init acct worker puller : %s " % (self.puller))
        logger.info("init acct stat pusher : %s " % (self.stat_pusher))
        # Acct-Status-Type -> handler class; instances expose ``acctounting``.
        self.acct_class = {
            STATUS_TYPE_START: RadiusAcctStart,
            STATUS_TYPE_STOP: RadiusAcctStop,
            STATUS_TYPE_UPDATE: RadiusAcctUpdate,
            STATUS_TYPE_ACCT_ON: RadiusAcctOnoff,
            STATUS_TYPE_ACCT_OFF: RadiusAcctOnoff
        }

    def find_nas(self, ip_addr):
        """Return the TrBas row whose ``ip_addr`` matches, cached for 600 seconds."""
        def fetch_result():
            table = models.TrBas.__table__
            with self.db_engine.begin() as conn:
                return conn.execute(
                    table.select().where(table.c.ip_addr == ip_addr)).first()
        return self.mcache.aget(bas_cache_key(ip_addr), fetch_result, expire=600)

    def do_stat(self, code, status_type=0):
        """Push statistics counter names for one accounting packet.

        Codes 4/5 (Accounting-Request/Response) produce ``acct_req``/``acct_resp``
        plus a counter for a known Acct-Status-Type. Any other code, including
        the ``0`` used for failures, produces ``['acct_drop']``.
        """
        try:
            stat_msg = ['acct_drop']
            if code in (4, 5):
                stat_msg = []
                if code == packet.AccountingRequest:
                    stat_msg.append('acct_req')
                elif code == packet.AccountingResponse:
                    stat_msg.append('acct_resp')
                status_name = self._STATUS_STAT_NAMES.get(status_type)
                if status_name:
                    stat_msg.append(status_name)
            self.stat_pusher.push(msgpack.packb(stat_msg))
        except Exception:
            # Statistics are best-effort: never let them break packet handling,
            # but do not swallow KeyboardInterrupt/SystemExit either.
            pass

    def process(self, message):
        """ZeroMQ pull callback: unpack ``[datagram, host, port]`` and handle it."""
        datagram, host, port = msgpack.unpackb(message[0])
        self.processAcct(datagram, host, port)

    def createAcctPacket(self, **kwargs):
        """Build an AcctMessage and run it through the MAC and VLAN parsers.

        ``vendor_id`` (default 0) is removed from ``kwargs`` and set as an
        attribute; the remaining kwargs go to ``message.AcctMessage``.
        """
        vendor_id = kwargs.pop('vendor_id', 0)
        acct_message = message.AcctMessage(**kwargs)
        acct_message.vendor_id = vendor_id
        acct_message = mac_parse.process(acct_message)
        acct_message = vlan_parse.process(acct_message)
        return acct_message

    def processAcct(self, datagram, host, port):
        """Validate one accounting datagram, acknowledge it and schedule billing.

        Any exception is logged and counted as ``acct_drop``; no reply is sent
        in that case.
        """
        try:
            bas = self.find_nas(host)
            if not bas:
                raise PacketError('[Radiusd] :: Dropping packet from unknown host %s' % host)
            secret, vendor_id = bas['bas_secret'], bas['vendor_id']
            req = self.createAcctPacket(packet=datagram, dict=self.dict,
                                        secret=six.b(str(secret)), vendor_id=vendor_id)
            status_type = req.get_acct_status_type()
            self.do_stat(req.code, status_type)
            logger.info("[Radiusd] :: Received radius request: %s" % (repr(req)))
            if self.config.system.debug:
                logger.debug(req.format_str())
            if req.code != packet.AccountingRequest:
                raise PacketError('non-AccountingRequest packet on accounting socket')
            if not req.VerifyAcctRequest():
                raise PacketError('VerifyAcctRequest error')

            # Acknowledge first; the billing handler runs later on the reactor.
            reply = req.CreateReply()
            self.pusher.push(msgpack.packb([reply.ReplyPacket(), host, port]))
            self.do_stat(reply.code)
            logger.info("[Radiusd] :: Send radius response: %s" % repr(reply))
            if self.config.system.debug:
                logger.debug(reply.format_str())

            if status_type in self.acct_class:
                acct_func = self.acct_class[status_type](
                    self.db_engine, self.mcache, None, req.get_ticket()).acctounting
                reactor.callLater(0.1, acct_func)
            else:
                logger.error('status_type <%s> not support' % status_type)
        except Exception as err:
            self.do_stat(0)
            errstr = 'RadiusError:Dropping invalid acct packet from {0} {1},{2}'.format(
                host, port, utils.safeunicode(err))
            logger.error(errstr)
            import traceback
            traceback.print_exc()
class RADIUSAcctWorker(object):
    """RADIUS accounting worker fed over ZeroMQ IPC sockets.

    Data flow:

    * Pulls msgpack-encoded ``[datagram, host, port]`` messages from
      ``ipc:///tmp/radiusd-acct-message``.
    * Acknowledges each valid Accounting-Request immediately by pushing
      ``[reply_bytes, host, port]`` to ``ipc:///tmp/radiusd-acct-result``.
    * Schedules the handler registered for the packet's Acct-Status-Type on
      the reactor.
    * Pushes statistics to ``ipc:///tmp/radiusd-stat-task``.
    """

    # Acct-Status-Type value (RFC 2866) -> statistics counter name.
    _STATUS_STAT_NAMES = {
        1: 'acct_start',
        2: 'acct_stop',
        3: 'acct_update',
        7: 'acct_on',
        8: 'acct_off',
    }

    def __init__(self, config, dbengine, radcache=None):
        """
        :param config: application config; ``config.system.debug`` is read.
        :param dbengine: database engine used via ``begin()``, or a falsy value
            to build one with ``get_engine(config)``.
        :param radcache: cache used via ``aget(key, fetch, expire=...)``.
            NOTE(review): the ``None`` default would fail on the first NAS
            lookup; callers presumably always pass one.
        """
        self.config = config
        self.dict = dictionary.Dictionary(
            os.path.join(os.path.dirname(toughradius.__file__), 'dictionarys/dictionary'))
        self.db_engine = dbengine or get_engine(config)
        self.mcache = radcache
        self.pusher = ZmqPushConnection(
            ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-acct-result'))
        self.stat_pusher = ZmqPushConnection(
            ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-stat-task'))
        self.puller = ZmqPullConnection(
            ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-acct-message'))
        self.puller.onPull = self.process
        logger.info("init acct worker pusher : %s " % (self.pusher))
        logger.info("init acct worker puller : %s " % (self.puller))
        logger.info("init acct stat pusher : %s " % (self.stat_pusher))
        # Acct-Status-Type -> handler class; instances expose ``acctounting``.
        self.acct_class = {
            STATUS_TYPE_START: RadiusAcctStart,
            STATUS_TYPE_STOP: RadiusAcctStop,
            STATUS_TYPE_UPDATE: RadiusAcctUpdate,
            STATUS_TYPE_ACCT_ON: RadiusAcctOnoff,
            STATUS_TYPE_ACCT_OFF: RadiusAcctOnoff
        }

    def find_nas(self, ip_addr):
        """Return the TrBas row whose ``ip_addr`` matches, cached for 600 seconds."""
        def fetch_result():
            table = models.TrBas.__table__
            with self.db_engine.begin() as conn:
                return conn.execute(
                    table.select().where(table.c.ip_addr == ip_addr)).first()
        return self.mcache.aget(bas_cache_key(ip_addr), fetch_result, expire=600)

    def do_stat(self, code, status_type=0):
        """Push statistics counter names for one accounting packet.

        Codes 4/5 (Accounting-Request/Response) produce ``acct_req``/``acct_resp``
        plus a counter for a known Acct-Status-Type. Any other code, including
        the ``0`` used for failures, produces ``['acct_drop']``.
        """
        try:
            stat_msg = ['acct_drop']
            if code in (4, 5):
                stat_msg = []
                if code == packet.AccountingRequest:
                    stat_msg.append('acct_req')
                elif code == packet.AccountingResponse:
                    stat_msg.append('acct_resp')
                status_name = self._STATUS_STAT_NAMES.get(status_type)
                if status_name:
                    stat_msg.append(status_name)
            self.stat_pusher.push(msgpack.packb(stat_msg))
        except Exception:
            # Statistics are best-effort: never let them break packet handling,
            # but do not swallow KeyboardInterrupt/SystemExit either.
            pass

    def process(self, message):
        """ZeroMQ pull callback: unpack ``[datagram, host, port]`` and handle it."""
        datagram, host, port = msgpack.unpackb(message[0])
        self.processAcct(datagram, host, port)

    def createAcctPacket(self, **kwargs):
        """Build an AcctMessage and run it through the MAC and VLAN parsers.

        ``vendor_id`` (default 0) is removed from ``kwargs`` and set as an
        attribute; the remaining kwargs go to ``message.AcctMessage``.
        """
        vendor_id = kwargs.pop('vendor_id', 0)
        acct_message = message.AcctMessage(**kwargs)
        acct_message.vendor_id = vendor_id
        acct_message = mac_parse.process(acct_message)
        acct_message = vlan_parse.process(acct_message)
        return acct_message

    def processAcct(self, datagram, host, port):
        """Validate one accounting datagram, acknowledge it and schedule billing.

        Any exception is logged and counted as ``acct_drop``; no reply is sent
        in that case.
        """
        try:
            bas = self.find_nas(host)
            if not bas:
                raise PacketError(
                    '[Radiusd] :: Dropping packet from unknown host %s' % host)
            secret, vendor_id = bas['bas_secret'], bas['vendor_id']
            req = self.createAcctPacket(packet=datagram, dict=self.dict,
                                        secret=six.b(str(secret)), vendor_id=vendor_id)
            status_type = req.get_acct_status_type()
            self.do_stat(req.code, status_type)
            logger.info("[Radiusd] :: Received radius request: %s" % (repr(req)))
            if self.config.system.debug:
                logger.debug(req.format_str())
            if req.code != packet.AccountingRequest:
                raise PacketError(
                    'non-AccountingRequest packet on accounting socket')
            if not req.VerifyAcctRequest():
                raise PacketError('VerifyAcctRequest error')

            # Acknowledge first; the billing handler runs later on the reactor.
            reply = req.CreateReply()
            self.pusher.push(msgpack.packb([reply.ReplyPacket(), host, port]))
            self.do_stat(reply.code)
            logger.info("[Radiusd] :: Send radius response: %s" % repr(reply))
            if self.config.system.debug:
                logger.debug(reply.format_str())

            if status_type in self.acct_class:
                acct_func = self.acct_class[status_type](
                    self.db_engine, self.mcache, None, req.get_ticket()).acctounting
                reactor.callLater(0.1, acct_func)
            else:
                logger.error('status_type <%s> not support' % status_type)
        except Exception as err:
            self.do_stat(0)
            errstr = 'RadiusError:Dropping invalid acct packet from {0} {1},{2}'.format(
                host, port, utils.safeunicode(err))
            logger.error(errstr)
            import traceback
            traceback.print_exc()
class RADIUSAuthWorker(protocol.DatagramProtocol):
    """RADIUS authentication worker fed over ZeroMQ IPC sockets.

    Data flow:

    * Pulls msgpack-encoded ``[datagram, host, port]`` messages from
      ``ipc:///tmp/radiusd-auth-message``; ``self.process`` is the pull
      callback.
    * Authenticates each request.
    * Pushes ``[reply_bytes, host, port]`` to ``ipc:///tmp/radiusd-auth-result``.
    * Pushes statistics to ``ipc:///tmp/radiusd-stat-task``.
    """

    def __init__(self, config, dbengine, radcache=None):
        """
        :param config: application config; ``config.system.secret`` and
            ``config.system.debug`` are read.
        :param dbengine: database engine used via ``begin()``, or a falsy value
            to build one with ``get_engine(config)``.
        :param radcache: cache used via ``aget(key, fetch, expire=...)``.
            NOTE(review): the ``None`` default would fail on the first NAS
            lookup; callers presumably always pass one.
        """
        self.config = config
        self.dict = dictionary.Dictionary(
            os.path.join(os.path.dirname(toughradius.__file__), 'dictionarys/dictionary'))
        self.db_engine = dbengine or get_engine(config)
        # Cipher keyed with the system secret; handed to RadiusAuth below.
        self.aes = utils.AESCipher(key=self.config.system.secret)
        self.mcache = radcache
        self.pusher = ZmqPushConnection(
            ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-auth-result'))
        self.stat_pusher = ZmqPushConnection(
            ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-stat-task'))
        self.puller = ZmqPullConnection(
            ZmqFactory(), ZmqEndpoint('connect', 'ipc:///tmp/radiusd-auth-message'))
        self.puller.onPull = self.process
        # Registers this DatagramProtocol on an OS-assigned UDP port (port 0).
        # Replies go out through ZeroMQ in process() (the transport.write call
        # there is commented out), so this socket looks unused here.
        # NOTE(review): confirm it is still needed.
        reactor.listenUDP(0, self)
        logger.info("init auth worker pusher : %s " % (self.pusher))
        logger.info("init auth worker puller : %s " % (self.puller))
        logger.info("init auth stat pusher : %s " % (self.stat_pusher))

    def find_nas(self, ip_addr):
        """Return the TrBas row whose ``ip_addr`` matches, cached for 600 seconds."""
        def fetch_result():
            # Cache-miss path: single-row lookup in its own transaction.
            table = models.TrBas.__table__
            with self.db_engine.begin() as conn:
                return conn.execute(
                    table.select().where(table.c.ip_addr == ip_addr)).first()
        return self.mcache.aget(bas_cache_key(ip_addr), fetch_result, expire=600)

    def do_stat(self, code):
        """Push one statistics counter name for a RADIUS packet code.

        Request/accept/reject codes map to ``auth_req``/``auth_accept``/
        ``auth_reject``; anything else (``0`` is used for failures) counts as
        ``auth_drop``.
        """
        try:
            stat_msg = []
            if code == packet.AccessRequest:
                stat_msg.append('auth_req')
            elif code == packet.AccessAccept:
                stat_msg.append('auth_accept')
            elif code == packet.AccessReject:
                stat_msg.append('auth_reject')
            else:
                stat_msg = ['auth_drop']
            self.stat_pusher.push(msgpack.packb(stat_msg))
        except:
            # Statistics are best-effort; any failure is deliberately ignored.
            pass

    def process(self, message):
        """ZeroMQ pull callback: authenticate one request and push the reply.

        :param message: multipart message whose first frame is a msgpack
            ``[datagram, host, port]`` triple.
        """
        datagram, host, port = msgpack.unpackb(message[0])
        reply = self.processAuth(datagram, host, port)
        # processAuth returns None for dropped packets; nothing is sent back.
        if not reply:
            return
        logger.info("[Radiusd] :: Send radius response: %s" % repr(reply))
        if self.config.system.debug:
            logger.debug(reply.format_str())
        self.pusher.push(msgpack.packb([reply.ReplyPacket(), host, port]))
        # Former direct UDP reply path, superseded by the ZeroMQ push above:
        # self.transport.write(reply.ReplyPacket(), (host,port))
        self.do_stat(reply.code)

    def createAuthPacket(self, **kwargs):
        """Build an AuthMessage and run it through the MAC and VLAN parsers.

        ``vendor_id`` (default 0) is removed from ``kwargs`` and set as an
        attribute; the remaining kwargs go to ``message.AuthMessage``.
        """
        vendor_id = kwargs.pop('vendor_id', 0)
        auth_message = message.AuthMessage(**kwargs)
        auth_message.vendor_id = vendor_id
        auth_message = mac_parse.process(auth_message)
        auth_message = vlan_parse.process(auth_message)
        return auth_message

    def processAuth(self, datagram, host, port):
        """Authenticate one raw datagram.

        :returns: an Access-Accept or Access-Reject reply. Returns None when
            the packet is dropped; any exception is logged and counted as
            ``auth_drop``.
        """
        try:
            bas = self.find_nas(host)
            if not bas:
                raise PacketError(
                    '[Radiusd] :: Dropping packet from unknown host %s' % host)
            secret, vendor_id = bas['bas_secret'], bas['vendor_id']
            req = self.createAuthPacket(packet=datagram, dict=self.dict,
                                        secret=six.b(str(secret)), vendor_id=vendor_id)
            self.do_stat(req.code)
            logger.info("[Radiusd] :: Received radius request: %s" % (repr(req)))
            if self.config.system.debug:
                logger.debug(req.format_str())
            if req.code != packet.AccessRequest:
                raise PacketError(
                    'non-AccessRequest packet on authentication socket')
            reply = req.CreateReply()
            reply.vendor_id = req.vendor_id
            # Fields handed to the RadiusAuth policy check.
            aaa_request = dict(account_number=req.get_user_name(),
                               domain=req.get_domain(),
                               macaddr=req.client_mac,
                               nasaddr=req.get_nas_addr(),
                               vlanid1=req.vlanid1,
                               vlanid2=req.vlanid2)
            auth_resp = RadiusAuth(self.db_engine, self.mcache, self.aes, aaa_request).authorize()
            # A positive code means authorization was refused; msg explains why.
            if auth_resp['code'] > 0:
                reply['Reply-Message'] = auth_resp['msg']
                reply.code = packet.AccessReject
                return reply
            # bypass == 0 in the response skips the password check.
            if 'bypass' in auth_resp and int(auth_resp['bypass']) == 0:
                is_pwd_ok = True
            else:
                is_pwd_ok = req.is_valid_pwd(auth_resp.get('passwd'))
            if not is_pwd_ok:
                reply['Reply-Message'] = "password not match"
                reply.code = packet.AccessReject
                return reply
            else:
                if u"input_rate" in auth_resp and u"output_rate" in auth_resp:
                    reply = rate_process.process(
                        reply, input_rate=auth_resp['input_rate'],
                        output_rate=auth_resp['output_rate'])
                attrs = auth_resp.get("attrs") or {}
                for attr_name in attrs:
                    try:
                        # todo: May have a type matching problem
                        reply.AddAttribute(utils.safestr(attr_name), attrs[attr_name])
                    except Exception as err:
                        # An unsupported attribute is logged and skipped; the
                        # accept still goes out.
                        errstr = "RadiusError:current radius cannot support attribute {0},{1}".format(
                            attr_name, utils.safestr(err.message))
                        logger.error(errstr)
                # Attributes collected on the request (resp_attrs) are copied
                # onto the reply.
                for attr, attr_val in req.resp_attrs.iteritems():
                    reply[attr] = attr_val
                reply['Reply-Message'] = 'success!'
                reply.code = packet.AccessAccept
                if not req.VerifyReply(reply):
                    raise PacketError('VerifyReply error')
                return reply
        except Exception as err:
            self.do_stat(0)
            errstr = 'RadiusError:Dropping invalid auth packet from {0} {1},{2}'.format(
                host, port, utils.safeunicode(err))
            logger.error(errstr)
            import traceback
            traceback.print_exc()
class RADIUSAcctWorker(TraceMix):
    """
    Accounting worker subprocess: handles billing logic and pushes results to
    the main RADIUS protocol process. Accounting is asynchronous: every
    accounting message gets an immediate response, and the billing logic then
    runs in the background.
    """

    def __init__(self, config, dbengine, radcache = None):
        """
        :param config: application config; ``config.mqproxy`` supplies the
            ``task_connect``, ``acct_result`` and ``acct_message`` endpoints.
        :param dbengine: database engine, or a falsy value to build one with
            ``get_engine(config)``.
        :param radcache: cache object stored as ``self.mcache`` and handed to
            the accounting handlers.
        """
        self.config = config
        # load_plugins is not defined here; presumably provided by TraceMix
        # and populates self.acct_req_plugins (used in createAcctPacket).
        self.load_plugins(load_types=['radius_acct_req'])
        self.db_engine = dbengine or get_engine(config)
        self.mcache = radcache
        self.dict = dictionary.Dictionary(os.path.join(os.path.dirname(taurusxradius.__file__), 'dictionarys/dictionary'))
        self.stat_pusher = ZmqPushConnection(ZmqFactory(), ZmqEndpoint('connect', config.mqproxy['task_connect']))
        self.pusher = ZmqPushConnection(ZmqFactory(), ZmqEndpoint('connect', config.mqproxy['acct_result']))
        self.puller = ZmqPullConnection(ZmqFactory(), ZmqEndpoint('connect', config.mqproxy['acct_message']))
        self.puller.onPull = self.process
        # Acct-Status-Type -> handler class; instances expose ``acctounting``.
        self.acct_class = {STATUS_TYPE_START: RadiusAcctStart,
         STATUS_TYPE_STOP: RadiusAcctStop,
         STATUS_TYPE_UPDATE: RadiusAcctUpdate,
         STATUS_TYPE_ACCT_ON: RadiusAcctOnoff,
         STATUS_TYPE_ACCT_OFF: RadiusAcctOnoff}
        logger.info('radius acct worker %s start' % os.getpid())
        logger.info('init acct worker pusher : %s ' % self.pusher)
        logger.info('init acct worker puller : %s ' % self.puller)
        logger.info('init acct stat pusher : %s ' % self.stat_pusher)

    def do_stat(self, code, status_type = 0, req = None):
        """Push a ``{'statattrs': [...], 'raddata': {...}}`` statistics message.

        Behavior by code:

        * Codes 4/5 record ``acct_req``/``acct_resp`` plus a counter for the
          Acct-Status-Type.
        * Status type 3 (interim update) also records the request's
          input/output totals, so ``req`` must be given in that case.
        * Any other code produces ``acct_drop``.
        """
        try:
            stat_msg = {'statattrs': ['acct_drop'],
             'raddata': {}}
            if code in (4, 5):
                stat_msg['statattrs'] = []
                if code == packet.AccountingRequest:
                    stat_msg['statattrs'].append('acct_req')
                elif code == packet.AccountingResponse:
                    stat_msg['statattrs'].append('acct_resp')
                if status_type == 1:
                    stat_msg['statattrs'].append('acct_start')
                elif status_type == 2:
                    stat_msg['statattrs'].append('acct_stop')
                elif status_type == 3:
                    stat_msg['statattrs'].append('acct_update')
                    # Raises AttributeError (swallowed below) if req is None.
                    stat_msg['raddata']['input_total'] = req.get_input_total()
                    stat_msg['raddata']['output_total'] = req.get_output_total()
                elif status_type == 7:
                    stat_msg['statattrs'].append('acct_on')
                elif status_type == 8:
                    stat_msg['statattrs'].append('acct_off')
            self.stat_pusher.push(msgpack.packb(stat_msg))
        except:
            # Statistics are best-effort; any failure is deliberately ignored.
            pass

    def process(self, message):
        """ZeroMQ pull callback: handle one ``[datagram, host, port]`` message.

        The reply (if any) is pushed back as ``[reply_bytes, host, port]``.
        """
        datagram, host, port = msgpack.unpackb(message[0])
        reply = self.processAcct(datagram, host, port)
        if reply is None:
            return
        else:
            self.pusher.push(msgpack.packb([reply.ReplyPacket(), host, port]))
            return

    def createAcctPacket(self, **kwargs):
        """Build an AcctMessage and pass it through every loaded request plugin."""
        vendor_id = kwargs.pop('vendor_id', 0)
        acct_message = message.AcctMessage(**kwargs)
        acct_message.vendor_id = vendor_id
        for plugin in self.acct_req_plugins:
            acct_message = plugin.plugin_func(acct_message)
        return acct_message

    def processAcct(self, datagram, host, port):
        """Validate one accounting datagram and schedule its billing handler.

        :returns: the Accounting-Response to send, or None when the packet is
            dropped (invalid code, failed verification, unsupported status
            type, unknown NAS, or any exception).
        """
        try:
            # Parse with an empty secret first so the NAS-Identifier can be
            # read when the source IP is not a registered NAS.
            req = self.createAcctPacket(packet=datagram, dict=self.dict, secret=six.b(''), vendor_id=0)
            # find_nas / find_nas_byid / log_trace are not defined in this
            # class; presumably provided by TraceMix.
            bas = self.find_nas(host) or self.find_nas_byid(req.get_nas_id())
            if not bas:
                raise PacketError(u'Unauthorized access Nas %s' % host)
            secret, vendor_id = bas['bas_secret'], bas['vendor_id']
            req.secret = six.b(str(secret))
            req.vendor_id = vendor_id
            self.log_trace(host, port, req)
            self.do_stat(req.code, req.get_acct_status_type(), req=req)
            if req.code != packet.AccountingRequest:
                errstr = u'Invalid accounting request code=%s' % req.code
                logger.error(errstr, tag='radius_acct_drop', trace='radius', username=req.get_user_name())
                return
            # NOTE(review): this verifies the *request* authenticator although
            # the log text below says "response".
            if not req.VerifyAcctRequest():
                errstr = u'Check accounting response failed, please check shared secret'
                logger.error(errstr, tag='radius_acct_drop', trace='radius', username=req.get_user_name())
                return
            status_type = req.get_acct_status_type()
            if status_type in self.acct_class:
                ticket = req.get_ticket()
                ticket['nas_addr'] = host
                # Billing runs on the reactor shortly after; the reply below is
                # returned without waiting for it.
                acct_func = self.acct_class[status_type](self.db_engine, self.mcache, None, ticket).acctounting
                reactor.callLater(0.05, acct_func)
            else:
                errstr = u'accounting type <%s> not supported' % status_type
                logger.error(errstr, tag='radius_acct_drop', trace='radius', username=req.get_user_name())
                return
            reply = req.CreateReply()
            # Trace and stat for the reply are deferred off the hot path.
            reactor.callLater(0.05, self.log_trace, host, port, req, reply)
            reactor.callLater(0.05, self.do_stat, reply.code)
            return reply
        except Exception as err:
            self.do_stat(0)
            logger.exception(err, tag='radius_acct_drop')
            return
class RADIUSAuthWorker(TraceMix):
    """
    Authentication worker subprocess: handles authentication/authorization
    logic and pushes the results to the main RADIUS protocol process.
    """

    def __init__(self, config, dbengine, radcache = None):
        """
        :param config: application config; ``config.system.secret`` and
            ``config.mqproxy['auth_result' | 'task_connect' | 'auth_message']``
            are read.
        :param dbengine: database engine used via ``begin()``, or a falsy value
            to build one with ``get_engine(config)``.
        :param radcache: cache used via ``aget(key, fetch, expire=...)``.
        """
        self.config = config
        # load_plugins / get_param_value come from outside this class
        # (presumably TraceMix); they populate auth_req_plugins and
        # auth_accept_plugins.
        self.load_plugins(load_types=['radius_auth_req', 'radius_accept'])
        self.dict = dictionary.Dictionary(os.path.join(os.path.dirname(taurusxradius.__file__), 'dictionarys/dictionary'))
        self.db_engine = dbengine or get_engine(config)
        self.aes = utils.AESCipher(key=self.config.system.secret)
        self.mcache = radcache
        self.reject_debug = int(self.get_param_value('radius_reject_debug', 0)) == 1
        self.pusher = ZmqPushConnection(ZmqFactory(), ZmqEndpoint('connect', config.mqproxy['auth_result']))
        self.stat_pusher = ZmqPushConnection(ZmqFactory(), ZmqEndpoint('connect', config.mqproxy['task_connect']))
        self.puller = ZmqPullConnection(ZmqFactory(), ZmqEndpoint('connect', config.mqproxy['auth_message']))
        self.puller.onPull = self.process
        logger.info('radius auth worker %s start' % os.getpid())
        logger.info('init auth worker pusher : %s ' % self.pusher)
        logger.info('init auth worker puller : %s ' % self.puller)
        logger.info('init auth stat pusher : %s ' % self.stat_pusher)
        # Maximum number of TrOnline rows before new requests are refused.
        self.license_ulimit = 50000

    def _account_bind_stmt(self, column, account_number):
        """Build a SELECT of ``column`` for every BAS bound to the account's area."""
        tbas = models.TrBas.__table__
        tcus = models.TrCustomer.__table__
        tuser = models.TrAccount.__table__
        tbn = models.TrBasNode.__table__
        return tbas.select().with_only_columns([tbas.c[column]]).where(
            tcus.c.customer_id == tuser.c.customer_id).where(
            tcus.c.node_id == tbn.c.node_id).where(
            tbn.c.bas_id == tbas.c.id).where(
            tuser.c.account_number == account_number)

    def _account_bind_values(self, cache_key, column, account_number, value):
        """Return the cached bound-BAS rows for the account that match ``value``.

        The cache entry holds *all* bound rows for the account, because the key
        only identifies the account. Filtering here stops a result computed for
        one NAS from being reused for another NAS.

        NOTE: matching is exact Python equality on the first column; the old
        SQL filter followed the database collation.
        """
        def fetch_result():
            stmt = self._account_bind_stmt(column, account_number)
            with self.db_engine.begin() as conn:
                return [v for v in conn.execute(stmt)]

        bound = self.mcache.aget(cache_key, fetch_result, expire=3600) or []
        return [v for v in bound if v[0] == value]

    def get_account_bind_nas_ipaddrs(self, account_number, ip_addr):
        """
        Return the BAS IP-address rows bound to the user's area that equal
        ``ip_addr``. An empty list means the NAS is not bound.
        """
        return self._account_bind_values(
            account_bind_basip_key(account_number), 'ip_addr', account_number, ip_addr)

    def get_account_bind_nas_ids(self, account_number, nas_id):
        """
        Return the BAS identifier rows bound to the user's area that equal
        ``nas_id``. An empty list means the NAS is not bound.
        """
        return self._account_bind_values(
            account_bind_basid_key(account_number), 'nas_id', account_number, nas_id)

    def do_stat(self, code):
        """Push a ``{'statattrs': [...], 'raddata': {}}`` message for a packet code.

        Request/accept/reject codes map to ``auth_req``/``auth_accept``/
        ``auth_reject``; anything else (``0`` is used for failures) counts as
        ``auth_drop``.
        """
        try:
            stat_msg = {'statattrs': [], 'raddata': {}}
            if code == packet.AccessRequest:
                stat_msg['statattrs'].append('auth_req')
            elif code == packet.AccessAccept:
                stat_msg['statattrs'].append('auth_accept')
            elif code == packet.AccessReject:
                stat_msg['statattrs'].append('auth_reject')
            else:
                stat_msg['statattrs'] = ['auth_drop']
            self.stat_pusher.push(msgpack.packb(stat_msg))
        except Exception:
            # Statistics are best-effort: never let them break packet handling,
            # but do not swallow KeyboardInterrupt/SystemExit either.
            pass

    def process(self, message):
        """ZeroMQ pull callback: authenticate one request and push the reply.

        Requests are dropped once the online-session count reaches
        ``self.license_ulimit``. Reject replies lose their Reply-Message unless
        the ``radius_reject_message`` parameter is non-zero.
        """
        table = models.TrOnline.__table__
        with self.db_engine.begin() as conn:
            count = conn.execute(table.count()).scalar()
        if count >= self.license_ulimit:
            logger.error(u'Online user empowerment has been limited <%s>' % self.license_ulimit)
            return
        datagram, host, port = msgpack.unpackb(message[0])
        reply = self.processAuth(datagram, host, port)
        if reply is None:
            return
        if reply.code == packet.AccessReject and 'Reply-Message' in reply and int(self.get_param_value('radius_reject_message', 0)) == 0:
            del reply['Reply-Message']
        self.pusher.push(msgpack.packb([reply.ReplyPacket(), host, port]))
        self.do_stat(reply.code)

    def createAuthPacket(self, **kwargs):
        """Build an AuthMessage and pass it through every loaded request plugin."""
        vendor_id = kwargs.pop('vendor_id', 0)
        auth_message = message.AuthMessage(**kwargs)
        auth_message.vendor_id = vendor_id
        for plugin in self.auth_req_plugins:
            auth_message = plugin.plugin_func(auth_message)
        return auth_message

    def freeReply(self, req):
        """
        Accept the user without authentication (free-auth mode) and apply the
        default policy taken from the ``radius_free_*`` parameters.
        """
        reply = req.CreateReply()
        reply.vendor_id = req.vendor_id
        reply['Reply-Message'] = u'User:%s FreeAuth Success' % req.get_user_name()
        reply.code = packet.AccessAccept
        reply_attrs = {'attrs': {}}
        reply_attrs['input_rate'] = int(self.get_param_value('radius_free_input_rate', 1048576))
        reply_attrs['output_rate'] = int(self.get_param_value('radius_free_output_rate', 4194304))
        reply_attrs['rate_code'] = self.get_param_value('radius_free_rate_code', 'freerate')
        reply_attrs['domain'] = self.get_param_value('radius_free_domain', 'freedomain')
        reply_attrs['attrs']['Session-Timeout'] = int(self.get_param_value('radius_max_session_timeout', 86400))
        for plugin in self.auth_accept_plugins:
            reply = plugin.plugin_func(reply, reply_attrs)
        return reply

    def rejectReply(self, req, errmsg = ''):
        """Build an Access-Reject carrying ``errmsg`` as its Reply-Message."""
        reply = req.CreateReply()
        reply.vendor_id = req.vendor_id
        reply['Reply-Message'] = errmsg
        reply.code = packet.AccessReject
        return reply

    def processAuth(self, datagram, host, port):
        """Authenticate one raw datagram.

        :returns: an Access-Accept or Access-Reject reply. Returns None when
            the packet is dropped (wrong code, failed reply verification, or
            any exception, which is logged and counted as ``auth_drop``).
        """
        try:
            # Parse with an empty secret first so the NAS-Identifier can be read
            # when the source IP is not a registered NAS.
            req = self.createAuthPacket(packet=datagram, dict=self.dict, secret=six.b(''), vendor_id=0)
            nas_id = req.get_nas_id()
            bastype = 'ipaddr'
            bas = self.find_nas(host)
            if not bas:
                bastype = 'nasid'
                bas = self.find_nas_byid(nas_id)
                if not bas:
                    raise PacketError(u'Unauthorized Access Nas %s' % host)
            secret, vendor_id = bas['bas_secret'], bas['vendor_id']
            req.secret = six.b(str(secret))
            req.vendor_id = vendor_id
            username = req.get_user_name()
            # radius_bypass == 2 accepts everyone via freeReply(); the value is
            # also forwarded to RadiusAuth.
            bypass = int(self.get_param_value('radius_bypass', 1))
            if req.code != packet.AccessRequest:
                errstr = u'Illegal Auth request, code=%s' % req.code
                logger.error(errstr, tag='radius_auth_drop', trace='radius', username=username)
                return
            self.log_trace(host, port, req)
            self.do_stat(req.code)
            if bypass == 2:
                reply = self.freeReply(req)
                self.log_trace(host, port, req, reply)
                return reply
            if not self.user_exists(username):
                errmsg = u'Auth Error:user:%s not exists' % utils.safeunicode(username)
                reply = self.rejectReply(req, errmsg)
                self.log_trace(host, port, req, reply)
                return reply
            # The NAS must belong to the user's area, checked by whichever
            # identifier located the NAS above.
            if bastype == 'ipaddr':
                bind_nasip_list = self.get_account_bind_nas_ipaddrs(username, host)
                if not bind_nasip_list:
                    errmsg = u'Nas:%s not bind user:%s area' % (host, username)
                    reply = self.rejectReply(req, errmsg)
                    self.log_trace(host, port, req, reply)
                    return reply
            elif bastype == 'nasid':
                bind_nasid_list = self.get_account_bind_nas_ids(username, nas_id)
                if not bind_nasid_list:
                    errmsg = u'Nas:%s not bind user:%s area' % (nas_id, username)
                    reply = self.rejectReply(req, errmsg)
                    self.log_trace(host, port, req, reply)
                    return reply
            aaa_request = dict(account_number=username, domain=req.get_domain(), macaddr=req.client_mac, nasaddr=req.get_nas_addr(), vlanid1=req.vlanid1, vlanid2=req.vlanid2, bypass=bypass, radreq=req)
            auth_resp = RadiusAuth(self.db_engine, self.mcache, self.aes, aaa_request).authorize()
            # A positive code means authorization was refused; msg explains why.
            if auth_resp['code'] > 0:
                reply = self.rejectReply(req, auth_resp['msg'])
                self.log_trace(host, port, req, reply)
                return reply
            reply = req.CreateReply()
            reply.code = packet.AccessAccept
            reply.vendor_id = req.vendor_id
            extmsg = u'domain=%s;' % auth_resp['domain'] if 'domain' in auth_resp else ''
            extmsg += u'rate_policy=%s;' % auth_resp['rate_code'] if 'rate_code' in auth_resp else ''
            reply['Reply-Message'] = u'User:%s Auth success; %s' % (username, extmsg)
            for plugin in self.auth_accept_plugins:
                reply = plugin.plugin_func(reply, auth_resp)
            if not req.VerifyReply(reply):
                errstr = u'User:%s Auth message error, Please check share secret' % username
                logger.error(errstr, tag='radius_auth_drop', trace='radius', username=username)
                return
            self.log_trace(host, port, req, reply)
            return reply
        except Exception as err:
            self.do_stat(0)
            logger.exception(err, tag='radius_auth_error')