class RADIUSAcctWorker(TraceMix):
    """Accounting worker subprocess.

    Handles the billing logic and pushes the result to the RADIUS protocol
    main process.  Accounting is processed asynchronously: every time an
    accounting message is received the response is pushed immediately, and
    the billing logic is then run in the background.
    """

    def __init__(self, config, dbengine, radcache=None):
        """Wire up plugins, the attribute dictionary and the ZMQ endpoints.

        :param config: configuration object; ``config.mqproxy.task_connect``,
            ``config.mqproxy.acct_connect`` and ``config.system.debug`` are
            read by this class.
        :param dbengine: database engine; when falsy one is created with
            ``get_engine(config)``.
        :param radcache: cache object handed to the accounting handlers.
        """
        self.config = config
        self.load_plugins(load_types=['radius_acct_req'])
        self.db_engine = dbengine or get_engine(config)
        self.mcache = radcache
        self.dict = dictionary.Dictionary(
            os.path.join(os.path.dirname(taurusxradius.__file__), 'dictionarys/dictionary'))
        self.stat_pusher = ZmqPushConnection(ZmqFactory())
        self.zmqrep = ZmqREPConnection(ZmqFactory())
        self.stat_pusher.tcpKeepalive = 1
        self.zmqrep.tcpKeepalive = 1
        self.stat_pusher.addEndpoints([ZmqEndpoint('connect', config.mqproxy.task_connect)])
        self.zmqrep.addEndpoints([ZmqEndpoint('connect', config.mqproxy.acct_connect)])
        # Every message arriving on the REP connection is handled by process().
        self.zmqrep.gotMessage = self.process
        # Acct-Status-Type -> handler class; On and Off share one handler.
        self.acct_class = {
            STATUS_TYPE_START: RadiusAcctStart,
            STATUS_TYPE_STOP: RadiusAcctStop,
            STATUS_TYPE_UPDATE: RadiusAcctUpdate,
            STATUS_TYPE_ACCT_ON: RadiusAcctOnoff,
            STATUS_TYPE_ACCT_OFF: RadiusAcctOnoff,
        }
        logger.info('radius acct worker %s start' % os.getpid())
        logger.info('init acct worker : %s ' % self.zmqrep)
        logger.info('init acct stat pusher : %s ' % self.stat_pusher)

    def do_stat(self, code, status_type=0, req=None):
        """Push a statistics message describing one accounting packet.

        Statistics are best-effort: a failure here is logged at debug level
        and never propagates to the caller.

        :param code: RADIUS packet code; any code other than 4 or 5 is
            counted as ``acct_drop``.
        :param status_type: Acct-Status-Type value of the request (0 when
            not applicable).
        :param req: the request message; read only for interim updates
            (``status_type == 3``) to report the traffic totals.
        """
        try:
            stat_msg = {'statattrs': ['acct_drop'], 'raddata': {}}
            if code in (4, 5):
                stat_msg['statattrs'] = []
                if code == packet.AccountingRequest:
                    stat_msg['statattrs'].append('acct_req')
                elif code == packet.AccountingResponse:
                    stat_msg['statattrs'].append('acct_resp')
                if status_type == 1:
                    stat_msg['statattrs'].append('acct_start')
                elif status_type == 2:
                    stat_msg['statattrs'].append('acct_stop')
                elif status_type == 3:
                    stat_msg['statattrs'].append('acct_update')
                    stat_msg['raddata']['input_total'] = req.get_input_total()
                    stat_msg['raddata']['output_total'] = req.get_output_total()
                elif status_type == 7:
                    stat_msg['statattrs'].append('acct_on')
                elif status_type == 8:
                    stat_msg['statattrs'].append('acct_off')
            self.stat_pusher.push(msgpack.packb(stat_msg))
        except Exception as err:
            # Deliberately swallowed so statistics can never break packet
            # handling, but no longer hides SystemExit/KeyboardInterrupt.
            logger.debug('[Radiusd] :: push acct stat error: %r' % (err,))

    def process(self, msgid, message):
        """Handle one queued accounting datagram and send back the reply.

        :param msgid: ZMQ message id used to route the reply.
        :param message: msgpack-encoded ``[datagram, host, port]``.
        """
        datagram, host, port = msgpack.unpackb(message)
        reply = self.processAcct(datagram, host, port)
        if reply is None:
            # processAcct() has already counted and logged the drop; without
            # this guard the call below raised AttributeError on None.
            # ``is None`` rather than truthiness: a reply packet may be an
            # empty, dict-like (hence falsy) object.
            return
        self.zmqrep.reply(msgid, msgpack.packb([reply.ReplyPacket(), host, port]))

    def createAcctPacket(self, **kwargs):
        """Build an accounting message and run it through the request plugins.

        :param kwargs: arguments for ``message.AcctMessage``; an optional
            ``vendor_id`` (default 0) is removed and set on the message.
        :return: the message returned by the last plugin.
        """
        vendor_id = kwargs.pop('vendor_id', 0)
        acct_message = message.AcctMessage(**kwargs)
        acct_message.vendor_id = vendor_id
        for plugin in self.acct_req_plugins:
            acct_message = plugin.plugin_func(acct_message)
        return acct_message

    def processAcct(self, datagram, host, port):
        """Validate an accounting request, schedule billing and build the reply.

        The billing handler, the reply trace and the reply statistics are
        scheduled on the reactor so that the reply can be returned right away.

        :param datagram: raw RADIUS packet.
        :param host: address the packet came from; used to look up the NAS.
        :param port: source port of the packet.
        :return: the reply packet, or ``None`` when the packet is dropped
            (any exception is counted as a drop and logged).
        """
        try:
            bas = self.find_nas(host)
            if not bas:
                raise PacketError('[Radiusd] :: Dropping packet from unknown host %s' % host)
            secret, vendor_id = bas['bas_secret'], bas['vendor_id']
            req = self.createAcctPacket(
                packet=datagram, dict=self.dict, secret=six.b(str(secret)), vendor_id=vendor_id)
            self.log_trace(host, port, req)
            self.do_stat(req.code, req.get_acct_status_type(), req=req)
            if self.config.system.debug:
                logger.debug('[Radiusd] :: Received radius request: %s' % req.format_str())
            else:
                logger.info('[Radiusd] :: Received radius request: %s' % repr(req))
            if req.code != packet.AccountingRequest:
                # Message corrected: this worker serves the accounting socket.
                raise PacketError('non-AccountingRequest packet on accounting socket')
            if not req.VerifyAcctRequest():
                raise PacketError('VerifyAcctRequest error')
            status_type = req.get_acct_status_type()
            if status_type in self.acct_class:
                ticket = req.get_ticket()
                if not ticket.get('nas_addr'):
                    ticket['nas_addr'] = host
                # NOTE: 'acctounting' is the handler's method name as
                # referenced by the original code.
                acct_func = self.acct_class[status_type](
                    self.db_engine, self.mcache, None, ticket).acctounting
                reactor.callLater(0.05, acct_func)
            else:
                raise ValueError('status_type <%s> not support' % status_type)
            reply = req.CreateReply()
            reactor.callLater(0.05, self.log_trace, host, port, req, reply)
            reactor.callLater(0.05, self.do_stat, reply.code)
            if self.config.system.debug:
                logger.debug('[Radiusd] :: Send radius response: %s' % reply.format_str())
            else:
                logger.info('[Radiusd] :: Send radius response: %s' % repr(reply))
            return reply
        except Exception as err:
            self.do_stat(0)
            logger.exception(err, tag='radius_acct_drop')
            return None
class RADIUSAuthWorker(TraceMix):
    """Authentication worker subprocess.

    Handles the authentication/authorization logic and pushes the result to
    the RADIUS protocol main process.
    """

    def __init__(self, config, dbengine, radcache=None):
        """Wire up plugins, the attribute dictionary and the ZMQ endpoints.

        :param config: configuration object; ``config.system.secret``,
            ``config.system.debug``, ``config.mqproxy.task_connect`` and
            ``config.mqproxy.auth_connect`` are read by this class.
        :param dbengine: database engine; when falsy one is created with
            ``get_engine(config)``.
        :param radcache: cache object used for the NAS binding lookups and
            handed to ``RadiusAuth``.
        """
        self.config = config
        self.load_plugins(load_types=['radius_auth_req', 'radius_accept'])
        self.dict = dictionary.Dictionary(
            os.path.join(os.path.dirname(taurusxradius.__file__), 'dictionarys/dictionary'))
        self.db_engine = dbengine or get_engine(config)
        self.aes = utils.AESCipher(key=self.config.system.secret)
        self.mcache = radcache
        self.stat_pusher = ZmqPushConnection(ZmqFactory())
        self.zmqrep = ZmqREPConnection(ZmqFactory())
        self.stat_pusher.tcpKeepalive = 1
        self.zmqrep.tcpKeepalive = 1
        self.stat_pusher.addEndpoints([ZmqEndpoint('connect', config.mqproxy.task_connect)])
        self.zmqrep.addEndpoints([ZmqEndpoint('connect', config.mqproxy.auth_connect)])
        # Every message arriving on the REP connection is handled by process().
        self.zmqrep.gotMessage = self.process
        # When enabled, internal errors are answered with an Access-Reject
        # carrying repr(err) instead of being dropped silently.
        self.reject_debug = int(self.get_param_value('radius_reject_debug', 0)) == 1
        logger.info('radius auth worker %s start' % os.getpid())
        logger.info('init auth worker : %s ' % self.zmqrep)
        logger.info('init auth stat pusher : %s ' % self.stat_pusher)

    def get_account_bind_nas(self, account_number):
        """Return the NAS IP addresses bound to the node of an account.

        The lookup goes through ``self.mcache.aget`` with a 600 expiry, so
        the database is only queried when the cache calls ``fetch_result``.

        :param account_number: account to look up.
        :return: whatever ``mcache.aget`` returns for the list of
            ``ip_addr`` values produced by the query.
        """

        def fetch_result():
            with self.db_engine.begin() as conn:
                sql = '\n select bas.ip_addr \n from tr_bas as bas,tr_customer as cus,tr_account as usr,tr_bas_node as bn\n where cus.customer_id = usr.customer_id\n and cus.node_id = bn.node_id\n and bn.bas_id = bas.id\n and usr.account_number = :account_number\n '
                cur = conn.execute(_sql(sql), account_number=account_number)
                ipaddrs = [addr['ip_addr'] for addr in cur]
                return ipaddrs

        return self.mcache.aget(account_bind_basip_key(account_number), fetch_result, expire=600)

    def do_stat(self, code):
        """Push a statistics message for one authentication packet.

        Statistics are best-effort: a failure here is logged at debug level
        and never propagates to the caller.

        :param code: RADIUS packet code; anything other than Access-Request,
            Access-Accept or Access-Reject is counted as ``auth_drop``.
        """
        try:
            stat_msg = {'statattrs': [], 'raddata': {}}
            if code == packet.AccessRequest:
                stat_msg['statattrs'].append('auth_req')
            elif code == packet.AccessAccept:
                stat_msg['statattrs'].append('auth_accept')
            elif code == packet.AccessReject:
                stat_msg['statattrs'].append('auth_reject')
            else:
                stat_msg['statattrs'] = ['auth_drop']
            self.stat_pusher.push(msgpack.packb(stat_msg))
        except Exception as err:
            # Deliberately swallowed so statistics can never break packet
            # handling, but no longer hides SystemExit/KeyboardInterrupt.
            logger.debug('[Radiusd] :: push auth stat error: %r' % (err,))

    def process(self, msgid, message):
        """Handle one queued authentication datagram and send back the reply.

        Nothing is sent when ``processAuth`` yields no reply (dropped packet).

        :param msgid: ZMQ message id used to route the reply.
        :param message: msgpack-encoded ``[datagram, host, port]``.
        """
        datagram, host, port = msgpack.unpackb(message)
        reply = self.processAuth(datagram, host, port)
        if not reply:
            return
        if reply.code == packet.AccessReject:
            logger.error(u'[Radiusd] :: Send Radius Reject %s' % repr(reply), tag='radius_auth_reject')
        else:
            logger.info(u'[Radiusd] :: Send radius response: %s' % repr(reply))
        if self.config.system.debug:
            logger.debug(reply.format_str())
        self.zmqrep.reply(msgid, msgpack.packb([reply.ReplyPacket(), host, port]))
        self.do_stat(reply.code)

    def createAuthPacket(self, **kwargs):
        """Build an auth message and run it through the request plugins.

        :param kwargs: arguments for ``message.AuthMessage``; an optional
            ``vendor_id`` (default 0) is removed and set on the message.
        :return: the message returned by the last plugin.
        """
        vendor_id = kwargs.pop('vendor_id', 0)
        auth_message = message.AuthMessage(**kwargs)
        auth_message.vendor_id = vendor_id
        for plugin in self.auth_req_plugins:
            auth_message = plugin.plugin_func(auth_message)
        return auth_message

    def freeReply(self, req):
        """Build an Access-Accept from the ``radius_free_*`` parameters.

        Used by ``processAuth`` when ``radius_bypass`` is 2; no user lookup
        or ``RadiusAuth`` call is made on that path.

        :param req: the request message being answered.
        :return: the reply after all accept plugins have processed it.
        """
        reply = req.CreateReply()
        reply.vendor_id = req.vendor_id
        reply['Reply-Message'] = 'user:%s auth success' % req.get_user_name()
        reply.code = packet.AccessAccept
        reply_attrs = {'attrs': {}}
        reply_attrs['input_rate'] = int(self.get_param_value('radius_free_input_rate', 1048576))
        reply_attrs['output_rate'] = int(self.get_param_value('radius_free_output_rate', 4194304))
        reply_attrs['rate_code'] = self.get_param_value('radius_free_rate_code', 'freerate')
        reply_attrs['domain'] = self.get_param_value('radius_free_domain', 'freedomain')
        reply_attrs['attrs']['Session-Timeout'] = int(
            self.get_param_value('radius_max_session_timeout', 86400))
        for plugin in self.auth_accept_plugins:
            reply = plugin.plugin_func(reply, reply_attrs)
        return reply

    def rejectReply(self, req, errmsg=''):
        """Build an Access-Reject for ``req`` carrying ``errmsg``.

        :param req: the request message being answered.
        :param errmsg: text placed in the ``Reply-Message`` attribute.
        :return: the reject reply packet.
        """
        reply = req.CreateReply()
        reply.vendor_id = req.vendor_id
        reply['Reply-Message'] = errmsg
        reply.code = packet.AccessReject
        return reply

    def processAuth(self, datagram, host, port):
        """Authenticate one Access-Request and build the reply.

        :param datagram: raw RADIUS packet.
        :param host: address the packet came from; used to look up the NAS
            and checked against the account's bound NAS list.
        :param port: source port of the packet.
        :return: an accept/reject reply packet, or ``None`` when the packet
            is dropped.  On an internal error the packet is dropped unless
            ``reject_debug`` is on and a request was decoded, in which case
            an Access-Reject carrying ``repr(err)`` is returned.
        """
        # Bound up front so the except handler can tell whether a request was
        # decoded before the failure; previously an early failure with
        # reject_debug enabled raised UnboundLocalError inside the handler.
        req = None
        try:
            bas = self.find_nas(host)
            if not bas:
                raise PacketError('[Radiusd] :: Dropping packet from unknown host %s' % host)
            secret, vendor_id = bas['bas_secret'], bas['vendor_id']
            req = self.createAuthPacket(
                packet=datagram, dict=self.dict, secret=six.b(str(secret)), vendor_id=vendor_id)
            username = req.get_user_name()
            bypass = int(self.get_param_value('radius_bypass', 1))
            if req.code != packet.AccessRequest:
                raise PacketError('non-AccessRequest packet on authentication socket')
            self.log_trace(host, port, req)
            self.do_stat(req.code)
            if self.config.system.debug:
                logger.debug('[Radiusd] :: Received radius request: %s' % req.format_str())
            else:
                logger.info('[Radiusd] :: Received radius request: %s' % repr(req))

            if bypass == 2:
                # Free-access mode: accept without any user checks.
                reply = self.freeReply(req)
                self.log_trace(host, port, req, reply)
                return reply

            if not self.user_exists(username):
                errmsg = u'[Radiusd] :: user:%s not exists' % username
                reply = self.rejectReply(req, errmsg)
                self.log_trace(host, port, req, reply)
                return reply

            bind_nas_list = self.get_account_bind_nas(username)
            if not bind_nas_list or host not in bind_nas_list:
                errmsg = u'[Radiusd] :: nas_addr:%s not bind for user:%s node' % (host, username)
                reply = self.rejectReply(req, errmsg)
                self.log_trace(host, port, req, reply)
                return reply

            aaa_request = dict(
                account_number=username,
                domain=req.get_domain(),
                macaddr=req.client_mac,
                nasaddr=req.get_nas_addr(),
                vlanid1=req.vlanid1,
                vlanid2=req.vlanid2,
                bypass=bypass,
                radreq=req)
            auth_resp = RadiusAuth(self.db_engine, self.mcache, self.aes, aaa_request).authorize()
            if auth_resp['code'] > 0:
                reply = self.rejectReply(req, auth_resp['msg'])
                self.log_trace(host, port, req, reply)
                return reply

            reply = req.CreateReply()
            reply.code = packet.AccessAccept
            reply.vendor_id = req.vendor_id
            reply['Reply-Message'] = 'user:%s auth success' % username
            reply_attrs = {}
            reply_attrs.update(auth_resp)
            # Attributes collected on the request override the auth result.
            reply_attrs.update(req.resp_attrs)
            for plugin in self.auth_accept_plugins:
                reply = plugin.plugin_func(reply, reply_attrs)
            if not req.VerifyReply(reply):
                raise PacketError('[Radiusd] :: user:%s auth verify reply error' % username)
            self.log_trace(host, port, req, reply)
            return reply
        except Exception as err:
            if self.reject_debug and req is not None:
                reply = self.rejectReply(req, repr(err))
                self.log_trace(host, port, req, reply)
                return reply
            # Either debugging rejects are off, or the failure happened
            # before a request existed to reject: count and log the drop.
            self.do_stat(0)
            logger.exception(err, tag='radius_auth_error')
            return None