Esempio n. 1
0
class SampleTask(CCTask):
    """Sample task that reports progress asynchronously from a timer thread.

    A ``threading.Timer`` re-arms itself every second and sends the time
    elapsed since the previous tick and since task start as feedback, while
    the main thread runs the long step.
    """

    x_time = None  # timestamp of the previous timer tick
    timer = None  # currently armed threading.Timer
    _feedback_stopped = False  # once set, timer_handler stops re-arming

    log = skytools.getLogger('cc.task.sample_async')

    def process_task(self, task):
        """Run the long step while feedback is sent in the background."""
        self.started = time.time()
        self._feedback_stopped = False
        self.timer_handler(1)  # launch asynchronous feedback thread
        try:
            # execute long running step
            time.sleep(5)
        finally:
            # stop asynchronous feedback thread even if the step raised;
            # the flag keeps a tick that is already running from re-arming
            # a fresh timer after cancel()
            self._feedback_stopped = True
            self.timer.cancel()
        # task done
        self.log.info('task %i done', task['task_id'])
        self.send_finished()

    def timer_handler(self, init=0):
        """Timer callback: send elapsed-time feedback, then re-arm.

        :param init: true for the initial call, which only records the start
                     tick and arms the timer without sending feedback.
        """
        if self._feedback_stopped:
            return
        t = time.time()
        if not init:
            ed = t - self.x_time
            et = t - self.started
            fb = {'elapsed_delta': ed, 'elapsed_total': et}
            self.send_feedback(fb)
        self.x_time = t
        self.timer = threading.Timer(1, self.timer_handler)
        self.timer.start()
Esempio n. 2
0
    def __init__(self,
                 name,
                 xtx,
                 zctx,
                 ioloop,
                 dealer_url,
                 router_url,
                 params=None):
        """Initialise a TailWriter worker thread.

        :param name: thread name; the part from its last '-' onwards is
                     appended to the logger name.
        :param xtx: crypto context, kept as ``self.xtx``.
        :param zctx: ZeroMQ context, kept as ``self.zctx``.
        :param ioloop: event loop, kept as ``self.ioloop``.
        :param dealer_url: kept as ``self.shared_url``.
        :param router_url: kept as ``self.direct_url``.
        :param params: optional mapping; every key/value pair is set as an
                       attribute on the worker. Defaults to no extra attributes.
        """
        super(TailWriter_Worker, self).__init__(name=name)

        # NOTE(review): if name has no '-', rfind() returns -1 and only the
        # last character of name is appended.
        self.log = skytools.getLogger('h:TailWriter_Worker' +
                                      name[name.rfind('-'):])

        self.xtx = xtx
        self.zctx = zctx
        self.ioloop = ioloop
        self.shared_url = dealer_url
        self.direct_url = router_url

        # None instead of a shared mutable {} default
        if params is None:
            params = {}
        for k, v in params.items():
            self.log.trace("setattr: %s -> %r", k, v)
            setattr(self, k, v)

        self.files = {}
        self.looping = True
Esempio n. 3
0
class InfoWriter(BaseProxyHandler):
    """Write received messages to files using a pool of worker threads.

    Settings read in ``startup`` are collected in ``self.wparams`` and passed
    to every ``InfoWriter_Worker`` started by ``launch_workers``.
    """

    CC_ROLES = ['remote']

    log = skytools.getLogger('h:InfoWriter')

    def startup(self):
        """Read the file-writing settings into ``self.wparams``."""
        super(InfoWriter, self).startup()

        self.workers = []
        cf = self.cf
        wp = self.wparams = {}  # passed to workers

        wp['dstdir'] = cf.getfile('dstdir')
        mask = cf.get('dstmask', '')
        if mask == '':
            # legacy configs: derive the mask from the 'host-subdirs' flag
            if cf.getbool('host-subdirs', 0):
                mask = '%(hostname)s/%(filename)s'
            else:
                mask = '%(hostname)s--%(filename)s'
        wp['dstmask'] = mask
        wp['bakext'] = cf.get('bakext', '')

        compressed = cf.get('write-compressed', '')
        wp['write_compressed'] = compressed
        assert compressed in [None, '', 'no', 'keep', 'yes']
        if compressed != 'yes':
            return

        method = cf.get('compression', '')
        wp['compression'] = method
        if method not in ('gzip', 'bzip2'):
            self.log.error("unsupported compression: %s", method)
        # NOTE(review): '' as the getint() default looks unintended; kept as is
        wp['compression_level'] = cf.getint('compression-level', '')

    def make_socket(self):
        """Bind an XREQ socket on a random inproc port for the workers."""
        base = 'inproc://workers'
        sock = self.zctx.socket(zmq.XREQ)
        self.worker_url = "%s:%d" % (base, sock.bind_to_random_port(base))
        return sock

    def launch_workers(self):
        """Start 'worker-threads' (default 10) InfoWriter_Worker threads."""
        count = self.cf.getint('worker-threads', 10)
        for idx in range(count):
            wname = "%s.worker-%i" % (self.hname, idx)
            self.log.info("starting %s", wname)
            worker = InfoWriter_Worker(wname, self.xtx, self.zctx,
                                       self.worker_url, self.wparams)
            worker.stat_inc = self.stat_inc  # XXX
            self.workers.append(worker)
            worker.start()

    def stop(self):
        """Stop the handler, then ask every worker to shut down."""
        super(InfoWriter, self).stop()
        self.log.info('stopping')
        for worker in self.workers:
            self.log.info("signalling %s", worker.name)
            worker.stop()
Esempio n. 4
0
class UdpListener(CCDaemon):
    """Daemon that receives a UDP datagram stream on a configured address."""

    log = skytools.getLogger('d:UdpListener')

    def reload(self):
        """Re-read the listen address and the stats period from config."""
        super(UdpListener, self).reload()

        cf = self.cf
        self.listen_host = cf.get('listen-host')
        self.listen_port = cf.getint('listen-port')
        self.stats_period = cf.getint('stats-period', 30)

    def startup(self):
        """Init plugins, bind the non-blocking UDP socket, start callbacks."""
        super(UdpListener, self).startup()

        # plugins should be ready before we start receiving udp stream
        self.init_plugins()

        addr = self.listen_addr = (self.listen_host, self.listen_port)
        sock = self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setblocking(0)
        try:
            sock.bind(addr)
        except Exception as e:
            self.log.exception("failed to bind to %s - %s", addr, e)
            raise

        loop = self.ioloop = IOLoop.instance()
        on_readable = functools.partial(self.handle_udp, sock)
        loop.add_handler(sock.fileno(), on_readable, loop.READ)

        self.timer_stats = PeriodicCallback(self.send_stats,
                                            self.stats_period * 1000, loop)
        self.timer_stats.start()
Esempio n. 5
0
class SkyLogPlugin(CCDaemonPlugin):
    """Base class for SkyLog plugins.

    Subclasses list the formats they accept in ``LOG_FORMATS`` and override
    the matching ``process_<format>`` method(s).
    """

    LOG_FORMATS = []  # json, netstr, syslog

    log = skytools.getLogger('d:SkyLog')

    def probe(self, log_fmt):
        """Return True if this plugin accepts messages in format log_fmt."""
        supported = log_fmt in self.LOG_FORMATS
        if not supported:
            self.log.debug("plugin %s does not support %r formatted messages",
                           self.__class__.__name__, log_fmt)
        return supported

    def init(self, log_fmt):
        """Select the format that process() will dispatch on."""
        assert log_fmt in self.LOG_FORMATS
        self.msg_format = log_fmt

    def process(self, msg):
        """Pass msg to the handler for the selected format.

        A format other than json/netstr/syslog raises KeyError.
        """
        handlers = {
            "json": self.process_json,
            "netstr": self.process_netstr,
            "syslog": self.process_syslog,
        }
        handler = handlers[self.msg_format]
        handler(msg)

    def process_json(self, msg):
        """Handle a JSON message; subclasses must override."""
        raise NotImplementedError

    def process_netstr(self, msg):
        """Handle a netstring message; subclasses must override."""
        raise NotImplementedError

    def process_syslog(self, msg):
        """Handle a syslog message; subclasses must override."""
        raise NotImplementedError
Esempio n. 6
0
class SampleTask(CCTask):
    """Sample synchronous task with optional simulated failures.

    ``cmd == 'crash-launch'`` raises in fetch_config and
    ``cmd == 'crash-run'`` raises in process_task.
    """

    log = skytools.getLogger('cc.task.sample')

    def fetch_config(self):
        """Fail early on request, otherwise defer to CCTask.fetch_config."""
        # crash before daemonizing if requested
        if self.task_info['task']['cmd'] == 'crash-launch':
            raise Exception('launch failed')
        return CCTask.fetch_config(self)

    def process_task(self, task):
        """Send three feedback messages one second apart, then finish."""
        # crash during run
        if task['cmd'] == 'crash-run':
            raise Exception('run failed')

        # do some work
        for step in range(3):
            time.sleep(1)
            self.send_feedback({'i': step})

        # task done
        self.log.info('task %s done', task['task_id'])
        self.send_finished()
Esempio n. 7
0
class PgLogForward(UdpListener):
    """UDP server handling the UDP stream sent by pg_logforward.

    Only the 'netstr' log format is accepted. Plugins that subclass
    PgLogForwardPlugin and support that format are loaded and initialised.
    """

    log = skytools.getLogger('d:PgLogForward')

    # Field order of a pg_logforward netstrings datagram.
    NETSTR_KEYS = (
        "elevel", "sqlerrcode", "username", "database", "remotehost",
        "funcname", "message", "detail", "hint", "context",
        "debug_query_string",
    )

    def reload(self):
        """Re-read config; asserts that 'log-format' is 'netstr'."""
        super(PgLogForward, self).reload()

        self.log_format = self.cf.get('log-format')
        assert self.log_format in ['netstr']
        self.log_parsing_errors = self.cf.getbool('log-parsing-errors', False)

    def _probe_func(self, cls):
        """ Custom plugin probing function """
        if not issubclass(cls, PgLogForwardPlugin):
            self.log.debug("plugin %s is not of supported type", cls.__name__)
            return False
        if self.log_format not in cls.LOG_FORMATS:
            self.log.debug("plugin %s does not support %r formatted messages",
                           cls.__name__, self.log_format)
            return False
        return True

    def init_plugins(self):
        """ Load suitable plugins and initialise them """
        self.load_plugins(log_fmt=self.log_format)
        for p in self.plugins:
            p.init(self.log_format)

    def parse_json(self, data):
        """ Parse JSON datagram sent by pg_logforward """
        raise NotImplementedError

    def parse_netstr(self, data):
        """Parse a netstrings datagram sent by pg_logforward.

        Fields are read in NETSTR_KEYS order. 'elevel' and 'sqlerrcode' are
        converted to int, and 'elevel_text' is looked up in pg_elevels_itoa.

        :returns: the parsed dict, or None if parsing failed. The failure is
                  logged when 'log-parsing-errors' is enabled.
        """
        keys = self.NETSTR_KEYS
        pos = 0
        try:
            res = {}
            while data:
                if pos >= len(keys):
                    raise ValueError("unexpected extra fields in datagram")
                res[keys[pos]], data = tnetstrings.parse(data)
                pos += 1
            res['elevel'] = int(res['elevel'])
            res['elevel_text'] = pg_elevels_itoa[res['elevel']]
            res['sqlerrcode'] = int(res['sqlerrcode'])
            return res

        except Exception as e:
            if self.log_parsing_errors:
                # pos can point past the last key (extra fields, or a failure
                # after all fields were read), so keys[pos] must be guarded
                # or the handler itself would raise IndexError
                field = keys[pos] if pos < len(keys) else '?'
                self.log.warning("netstr parsing error: %s", e)
                self.log.debug("failed netstring: {%s} [%i] %r", field,
                               len(data), data)
            return None
Esempio n. 8
0
class TaskState(object):
    """Track a task process through its pidfile, with a watchdog timer."""

    log = skytools.getLogger('d:TaskState')

    def __init__(self, uid, name, info, ioloop, cc, xtx):
        """
        :param uid: id used in the reply routing key 'task.reply.<uid>'.
        :param name: task name, used in log messages.
        :param info: task info dict. info['config']['pidfile'] is read here;
                     info['task']['task_handler'] and ['task_id'] go into
                     replies.
        :param ioloop: event loop driving the watchdog PeriodicCallback.
        :param cc: destination that reply messages are sent to.
        :param xtx: crypto context used to create reply messages.
        """
        self.uid = uid
        self.name = name
        self.info = info
        self.pidfile = info['config']['pidfile']
        self.ioloop = ioloop
        self.cc = cc
        self.xtx = xtx
        self.timer = None
        self.timer_tick = 1  # watchdog period in seconds
        self.heartbeat = False  # if true, send 'running' on every live tick
        self.start_time = None
        self.dead_since = None

    def start(self):
        """Record the start time and start the watchdog timer."""
        self.start_time = time.time()
        self.timer = PeriodicCallback(self.watchdog, self.timer_tick * 1000,
                                      self.ioloop)
        self.timer.start()

    def stop(self):
        """Send SIGINT to the process named in the pidfile; errors are logged."""
        try:
            self.log.info('Signalling %s', self.name)
            skytools.signal_pidfile(self.pidfile, signal.SIGINT)
        except Exception:
            self.log.exception('signal_pidfile failed: %s', self.pidfile)

    def watchdog(self):
        """Check whether the task process is still alive.

        While it is alive, 'running' is sent on each tick if heartbeat is
        set. Once it is gone, the timer is stopped and 'stopped' is sent.
        """
        # signal 0 presumably only probes the pid (kill(pid, 0) semantics)
        live = skytools.signal_pidfile(self.pidfile, 0)
        if live:
            self.log.debug('%s is alive', self.name)
            if self.heartbeat:
                self.send_reply('running')
        else:
            self.log.info('%s is over', self.name)
            self.dead_since = time.time()
            self.timer.stop()
            self.timer = None
            self.send_reply('stopped')

    def ccpublish(self, msg):
        """Wrap msg with the crypto context and send it to self.cc."""
        assert isinstance(msg, TaskReplyMessage)
        cmsg = self.xtx.create_cmsg(msg)
        cmsg.send_to(self.cc)

    def send_reply(self, status, feedback=None):
        """Publish a TaskReplyMessage with the given status.

        :param status: status string, e.g. 'running' or 'stopped'.
        :param feedback: optional dict; defaults to a fresh empty dict
                         (not a shared mutable default).
        """
        if feedback is None:
            feedback = {}
        msg = TaskReplyMessage(req='task.reply.%s' % self.uid,
                               handler=self.info['task']['task_handler'],
                               task_id=self.info['task']['task_id'],
                               status=status,
                               feedback=feedback)
        self.ccpublish(msg)
Esempio n. 9
0
 def __init__(self, name, xtx, zctx, url, connstr, func_list):
     """Prepare a DB worker thread called *name*.

     :param xtx: crypto context, kept as ``self.xtx``.
     :param zctx: ZeroMQ context, kept as ``self.zctx``.
     :param url: master socket URL, kept as ``self.master_url``.
     :param connstr: database connection string.
     :param func_list: list of function names, stored unchanged.
     """
     super(DBWorker, self).__init__(name=name)
     self.log = skytools.getLogger(name)
     # messaging
     self.xtx, self.zctx, self.master_url = xtx, zctx, url
     # database settings
     self.connstr, self.func_list = connstr, func_list
     # not connected yet; both start out unset
     self.db = self.master = None
     self.looping = True
Esempio n. 10
0
class TaskInfo:
    """State and collected replies for one task sent via CCReqStream."""

    log = skytools.getLogger('TaskInfo')

    def __init__(self, task, task_cbfunc, ccrq):
        """
        :param task: task message; task['task_id'] is kept as ``self.uuid``.
        :param task_cbfunc: called as task_cbfunc(is_done, reply_msg).
        :param ccrq: request stream used to send and resend the task.
        """
        self.task = task
        self.uuid = task['task_id']
        self.task_cbfunc = task_cbfunc
        self.replies = []
        self.ccrq = ccrq
        self.retry_count = 3  # resends allowed after timeouts
        self.retry_timeout = 15  # timeout handed to ccquery_async / CCReqStream
        self.query_id = None

    def send_task(self):
        """Send the task with process_reply registered as reply callback."""
        self.log.debug('')
        self.query_id = self.ccrq.ccquery_async(
            self.task, self.process_reply, self.retry_timeout)

    def process_reply(self, msg):
        """Handle one reply, or a timeout, for this task.

        :param msg: reply message, or None when a timeout occurred.
        :returns: (keep, timeout) tuple for CCReqStream.
        """
        if msg is None:
            # timeout: resend while retries remain, then give up
            if self.retry_count <= 0:
                self.log.error('timeout, task failed')
                self.task_cbfunc(True, msg)
                return (False, 0)
            self.log.warning('timeout, resending')
            self.retry_count -= 1
            self.ccrq.resend(self.query_id)
            return (True, self.retry_timeout)

        parts = msg.req.split('.')
        if parts[0] == 'error':
            done = True
            self.log.error('got error: %r', msg)
        elif parts[:2] == ['task', 'reply']:
            done = msg.status in ('finished', 'failed', 'stopped')
            self.log.debug('got result: %r', msg)
        else:
            done = False
            self.log.debug('got random: %r', msg)

        self.replies.append(msg)
        self.task_cbfunc(done, msg)
        self.log.debug('done=%r', done)
        # keep waiting (without a timeout) until a final reply arrives
        return (not done, 0)
Esempio n. 11
0
 def __init__(self, name, xtx, zctx, url, connstr, func_list):
     """Set up a DB worker thread; no connection is opened here.

     :param name: thread name, also used as the logger name.
     :param url: master socket URL, kept as ``self.master_url``.
     :param connstr: database connection string.
     :param func_list: list of function names, stored unchanged.
     """
     super(DBWorker, self).__init__(name=name)
     self.log = skytools.getLogger(name)

     self.xtx = xtx
     self.zctx = zctx
     self.master_url = url

     self.connstr = connstr
     self.func_list = func_list

     self.db = self.master = None  # both unset until assigned elsewhere
     self.looping = True
Esempio n. 12
0
class Disposer(CCHandler):
    """Handler that drops every message, only updating its counters."""

    CC_ROLES = ['local', 'remote']

    log = skytools.getLogger('h:Disposer')

    def handle_msg(self, cmsg):
        """Throw cmsg away, counting it and its size in stats."""
        self.log.trace('')
        inc = self.stat_inc
        inc('disposed_count')
        inc('disposed_bytes', cmsg.get_size())
Esempio n. 13
0
    def __init__(self, service_type, args):
        """Set up crypto contexts, hostname and log forwarding for the job.

        :param service_type: passed unchanged to the parent constructor.
        :param args: passed unchanged to the parent constructor.
        """
        # no crypto for logs
        self.logxtx = CryptoContext(None)
        # NOTE(review): a no-crypto self.xtx is installed before the parent
        # constructor runs, presumably because something it triggers already
        # uses self.xtx; it is replaced by the configured one at the end.
        self.xtx = CryptoContext(None)

        super(CCJob, self).__init__(service_type, args)

        self.hostname = socket.gethostname()

        # attach a CallbackLogger that hands log records to self.emit_log;
        # getLogger() with no name presumably returns the root logger.
        # NOTE(review): each instance adds another handler.
        root = skytools.getLogger()
        root.addHandler(CallbackLogger(self.emit_log))

        # self.cf is presumably set up by the parent constructor above
        self.xtx = CryptoContext(self.cf)
Esempio n. 14
0
class LocalLogger(CCHandler):
    """Re-emit received log messages through this process's own logger."""

    CC_ROLES = ['local', 'remote']

    log = skytools.getLogger('h:LocalLogger')

    def handle_msg(self, cmsg):
        """Log cmsg's payload locally; non-log payloads are logged by repr."""
        msg = cmsg.get_payload(self.xtx)
        if not hasattr(msg, 'log_level'):
            self.log.info('non-log msg: %r', msg)
            return

        # local "HH:MM:SS," followed by the three fractional-second digits
        clock = time.strftime("%H:%M:%S,", time.localtime(msg.log_time))
        millis = ("%.3f" % (msg.log_time % 1))[2:]
        self.log.info('[%s@%s] %s %s %s', msg.job_name, msg.hostname,
                      clock + millis, msg.log_level, msg.log_msg)
Esempio n. 15
0
    def __init__(self, name, xtx, zctx, url, params=None):
        """Initialise an InfoWriter worker thread.

        :param name: thread name; the part from its last '-' onwards is
                     appended to the logger name.
        :param xtx: crypto context, kept as ``self.xtx``.
        :param zctx: ZeroMQ context, kept as ``self.zctx``.
        :param url: master socket URL, kept as ``self.master_url``.
        :param params: optional mapping; every key/value pair is set as an
                       attribute on the worker. Defaults to no extra attributes.
        """
        super(InfoWriter_Worker, self).__init__(name=name)

        # NOTE(review): if name has no '-', rfind() returns -1 and only the
        # last character of name is appended.
        self.log = skytools.getLogger('h:InfoWriter_Worker' + name[name.rfind('-'):])

        self.xtx = xtx
        self.zctx = zctx
        self.master_url = url

        # None instead of a shared mutable {} default
        if params is None:
            params = {}
        for k, v in params.items():
            self.log.trace("setattr: %s -> %r", k, v)
            setattr(self, k, v)

        self.looping = True
Esempio n. 16
0
File: filter.py Progetto: markokr/cc
class Filter(CCHandler):
    """Forward messages to another handler, filtered by destination.

    A destination matching any 'exclude' pattern is dropped. Otherwise, when
    'include' patterns are configured, it must match one of them.
    """

    CC_ROLES = ['local', 'remote']

    log = skytools.getLogger('h:Filter')

    def __init__(self, hname, hcf, ccscript):
        """Read 'forward-to', 'include' and 'exclude' from the handler config."""
        super(Filter, self).__init__(hname, hcf, ccscript)

        self.fwd_hname = self.cf.get('forward-to')
        self.fwd_handler = ccscript.get_handler(self.fwd_hname)

        self.includes = _hint_list(self.cf.getlist('include', []))
        self.excludes = _hint_list(self.cf.getlist('exclude', []))

    @staticmethod
    def _matches_any(dest, patterns):
        """True if dest matches one of the (pattern, wild) pairs."""
        return any((not wild and dest == pat) or fnmatch.fnmatchcase(dest, pat)
                   for pat, wild in patterns)

    def handle_msg(self, cmsg):
        """Forward cmsg if its destination passes the filters; update stats."""
        dest = cmsg.get_dest()
        size = cmsg.get_size()

        if self._matches_any(dest, self.excludes):
            stat = 'dropped'
        elif self.includes and not self._matches_any(dest, self.includes):
            stat = 'dropped'
        else:
            try:
                self.fwd_handler.handle_msg(cmsg)
                stat = 'ok'
            except Exception:
                self.log.exception('crashed, dropping msg: %s', dest)
                stat = 'crashed'

        # totals first, then the per-outcome counters
        for suffix in ('', '.' + stat):
            self.stat_inc('filter.count' + suffix)
            self.stat_inc('filter.bytes' + suffix, size)
Esempio n. 17
0
class SkyLog(UdpListener):
    """UDP daemon receiving log records sent by skytools' skylog."""

    log = skytools.getLogger('d:SkyLog')

    def reload(self):
        """Re-read config; asserts that 'log-format' is 'netstr'."""
        super(SkyLog, self).reload()

        cf = self.cf
        self.log_format = cf.get('log-format')
        assert self.log_format in ['netstr']
        self.log_parsing_errors = cf.getbool('log-parsing-errors', False)

    def _probe_func(self, cls):
        """Accept SkyLogPlugin subclasses that support the configured format."""
        if not issubclass(cls, SkyLogPlugin):
            self.log.debug("plugin %s is not of supported type", cls.__name__)
            return False
        if self.log_format in cls.LOG_FORMATS:
            return True
        self.log.debug("plugin %s does not support %r formatted messages",
                       cls.__name__, self.log_format)
        return False

    def init_plugins(self):
        """Load the plugins for the configured format and init each one."""
        fmt = self.log_format
        self.load_plugins(log_fmt=fmt)
        for plugin in self.plugins:
            plugin.init(fmt)

    def parse_json(self, data):
        """Parse a JSON datagram (not implemented)."""
        raise NotImplementedError

    def parse_netstr(self, data):
        """Parse a tnetstrings datagram sent by skylog.

        :returns: the first decoded value. Trailing bytes are logged and
                  ignored. Returns None if parsing raises (logged when
                  'log-parsing-errors' is enabled).
        """
        try:
            msg, leftover = tnetstrings.parse(data)
            if leftover:
                self.log.warning("netstr parsing leftover: %r", leftover)
                self.log.debug("failed tnetstring: [%i] %r", len(data), data)
            return msg
        except Exception as e:
            if self.log_parsing_errors:
                self.log.warning("netstr parsing error: %s", e)
                self.log.debug("failed tnetstring: [%i] %r", len(data), data)
            return None
Esempio n. 18
0
File: jobmgr.py Progetto: postsql/cc
class JobMgr(CCHandler):
    """Start the local daemons listed in 'daemons' and stop them on shutdown."""

    log = skytools.getLogger('h:JobMgr')

    CC_ROLES = ['local']

    def __init__(self, hname, hcf, ccscript):
        """Collect settings from ccscript and start one job per daemon name."""
        super(JobMgr, self).__init__(hname, hcf, ccscript)

        self.cc_config = ccscript.args[0]

        self.local_url = ccscript.local_url
        self.cc_job_name = ccscript.job_name

        # pass -q / -v command-line flags on to the started jobs
        opts = ccscript.options
        self.job_args_extra = []
        if opts.quiet:
            self.job_args_extra.append("-q")
        if opts.verbose:
            self.job_args_extra += ["-v"] * opts.verbose

        self.jobs = {}
        for dname in self.cf.getlist('daemons'):
            self.add_job(dname, make_job_defaults(ccscript.cf, dname))

        # NOTE(review): replaced only after the jobs above were created, so
        # they got the context set by the base constructor -- confirm intent
        self.xtx = CryptoContext(None)

    def add_job(self, jname, defs):
        """Build a JobState for jname from this config file and start it."""
        jcf = skytools.Config(jname, self.cf.filename, user_defs=defs)
        job = JobState(jname, jcf, self.local_url, self.ioloop, self.xtx)
        self.jobs[jname] = job
        job.start(self.job_args_extra)

    def handle_msg(self, cmsg):
        """Requests are not acted upon; they are only logged."""
        self.log.warning('JobMgr req: %s', cmsg)

    def stop(self):
        """Stop the handler and every managed job."""
        super(JobMgr, self).stop()
        self.log.info('Stopping CC daemons')
        for job in self.jobs.values():
            self.log.debug("stopping %s", job.jname)
            job.stop()
Esempio n. 19
0
    def __init__(self, name, xtx, zctx, url, params=None):
        """Initialise an InfoWriter worker thread.

        :param name: thread name; the part from its last '-' onwards is
                     appended to the logger name.
        :param xtx: crypto context, kept as ``self.xtx``.
        :param zctx: ZeroMQ context, kept as ``self.zctx``.
        :param url: master socket URL, kept as ``self.master_url``.
        :param params: optional mapping; every key/value pair is set as an
                       attribute on the worker. Defaults to no extra attributes.
        """
        super(InfoWriter_Worker, self).__init__(name=name)

        # NOTE(review): if name has no '-', rfind() returns -1 and only the
        # last character of name is appended.
        self.log = skytools.getLogger('h:InfoWriter_Worker' +
                                      name[name.rfind('-'):])

        self.xtx = xtx
        self.zctx = zctx
        self.master_url = url

        # None instead of a shared mutable {} default
        if params is None:
            params = {}
        for k, v in params.items():
            self.log.trace("setattr: %s -> %r", k, v)
            setattr(self, k, v)

        self.looping = True
Esempio n. 20
0
class Delay(CCHandler):
    """Hold received messages for 'delay' seconds, then forward them.

    Messages are queued together with their due time. A PeriodicCallback
    running every ``tick`` milliseconds forwards the ones that are due.
    """

    CC_ROLES = ['local', 'remote']

    log = skytools.getLogger('h:Delay')

    tick = 250  # ms

    def __init__(self, hname, hcf, ccscript):
        """Read 'forward-to' and 'delay' (seconds, default 0); start the timer."""
        super(Delay, self).__init__(hname, hcf, ccscript)

        self.fwd_hname = self.cf.get('forward-to')
        self.delay = self.cf.getint('delay', 0)

        self.fwd_handler = ccscript.get_handler(self.fwd_hname)
        self.queue = collections.deque()  # (due_time, cmsg) in arrival order

        self.timer = PeriodicCallback(self.process_queue, self.tick,
                                      self.ioloop)
        self.timer.start()

    def handle_msg(self, cmsg):
        """Got message from client -- queue it until it is due."""
        self.queue.append((time.time() + self.delay, cmsg))

    def process_queue(self):
        """Forward queued messages from the head while they are due.

        Entries are appended in arrival order with a constant delay, so
        (barring wall-clock jumps) the head is always the earliest due.
        """
        now = time.time()
        queue = self.queue
        # explicit emptiness check instead of catching IndexError, which
        # would also hide unrelated IndexErrors raised inside the loop
        while queue and queue[0][0] <= now:
            _due, cmsg = queue.popleft()
            size = cmsg.get_size()
            try:
                self.fwd_handler.handle_msg(cmsg)
                stat = 'ok'
            except Exception:
                self.log.exception('crashed, dropping msg: %s', cmsg.get_dest())
                stat = 'crashed'
            self.stat_inc('delay.count')
            self.stat_inc('delay.bytes', size)
            self.stat_inc('delay.count.%s' % stat)
            self.stat_inc('delay.bytes.%s' % stat, size)
Esempio n. 21
0
    def __init__(self, name, xtx, zctx, ioloop, dealer_url, router_url, params=None):
        """Initialise a TailWriter worker thread.

        :param name: thread name; the part from its last '-' onwards is
                     appended to the logger name.
        :param xtx: crypto context, kept as ``self.xtx``.
        :param zctx: ZeroMQ context, kept as ``self.zctx``.
        :param ioloop: event loop, kept as ``self.ioloop``.
        :param dealer_url: kept as ``self.shared_url``.
        :param router_url: kept as ``self.direct_url``.
        :param params: optional mapping; every key/value pair is set as an
                       attribute on the worker. Defaults to no extra attributes.
        """
        super(TailWriter_Worker, self).__init__(name=name)

        # NOTE(review): if name has no '-', rfind() returns -1 and only the
        # last character of name is appended.
        self.log = skytools.getLogger('h:TailWriter_Worker' + name[name.rfind('-'):])

        self.xtx = xtx
        self.zctx = zctx
        self.ioloop = ioloop
        self.shared_url = dealer_url
        self.direct_url = router_url

        # None instead of a shared mutable {} default
        if params is None:
            params = {}
        for k, v in params.items():
            self.log.trace("setattr: %s -> %r", k, v)
            setattr(self, k, v)

        self.files = {}
        self.looping = True
Esempio n. 22
0
class DBHandler(BaseProxyHandler):
    """Proxy requests to a pool of DBWorker threads."""

    CC_ROLES = ['remote']

    log = skytools.getLogger('h:DBHandler')

    def startup(self):
        """Start the proxy and reset the worker list."""
        super(DBHandler, self).startup()
        self.workers = []

    def make_socket(self):
        """Bind an XREQ socket on a random inproc port for the workers."""
        base = 'inproc://workers'
        sock = self.zctx.socket(zmq.XREQ)
        self.worker_url = "%s:%d" % (base, sock.bind_to_random_port(base))
        return sock

    def launch_workers(self):
        """Start 'worker-threads' (default 10) DBWorker threads on 'db'."""
        cf = self.cf
        nworkers = cf.getint('worker-threads', 10)
        func_list = cf.getlist('allowed-functions', [])
        self.log.info('allowed-functions: %r', func_list)
        connstr = cf.get('db')
        for idx in range(nworkers):
            wname = "%s.worker-%i" % (self.hname, idx)
            self.log.info('starting %s', wname)
            worker = DBWorker(wname, self.xtx, self.zctx, self.worker_url,
                              connstr, func_list)
            self.workers.append(worker)
            worker.start()

    def stop(self):
        """Stop the proxy, then ask each worker to shut down."""
        super(DBHandler, self).stop()
        self.log.info("stopping")
        for worker in self.workers:
            self.log.info("signalling %s", worker.name)
            worker.stop()
Esempio n. 23
0
class QueryInfo:
    """Store callback details for one query."""
    log = skytools.getLogger('QueryInfo')

    def __init__(self, qid, cmsg, cbfunc, rqs):
        """
        :param qid: query id, passed to rqs.remove_query when finished.
        :param cmsg: original message, re-sent by send_to().
        :param cbfunc: called with a reply (or None on timeout); must
                       return a (keep, timeout) tuple.
        :param rqs: owning request stream; its ioloop and remove_query
                    are used.
        """
        self.qid = qid
        self.orig_cmsg = cmsg
        self.cbfunc = cbfunc
        self.timeout_ref = None
        self.ioloop = rqs.ioloop
        self.remove_query = rqs.remove_query

    def on_timeout(self):
        """Called by ioloop on timeout; callback errors are logged, not raised."""
        try:
            self.timeout_ref = None
            self.launch_cb(None)
        except Exception:
            # Exception rather than a bare except, so SystemExit and
            # KeyboardInterrupt are not swallowed
            self.log.exception('timeout callback crashed')

    def launch_cb(self, arg):
        """Run the callback; re-arm the timeout or drop the query."""
        keep, timeout = self.cbfunc(arg)
        self.log.trace('keep=%r', keep)
        if keep:
            self.set_timeout(timeout)
        else:
            self.remove_query(self.qid)

    def set_timeout(self, timeout):
        """Replace the pending timeout; a falsy timeout (None/0) just drops it.

        :param timeout: seconds from now (added to time.time()).
        """
        if self.timeout_ref:
            self.ioloop.remove_timeout(self.timeout_ref)
            self.timeout_ref = None
        if timeout:
            deadline = time.time() + timeout
            self.timeout_ref = self.ioloop.add_timeout(deadline,
                                                       self.on_timeout)

    def send_to(self, cc):
        """(Re)send the original message to cc."""
        self.orig_cmsg.send_to(cc)
Esempio n. 24
0
class PgLogForward(CCDaemon):
    """UDP daemon receiving log records sent by pg_logforward."""

    log = skytools.getLogger('d:PgLogForward')

    def reload(self):
        """Re-read listen address, log format and stats settings."""
        super(PgLogForward, self).reload()

        cf = self.cf
        self.listen_host = cf.get('listen-host')
        self.listen_port = cf.getint('listen-port')
        self.log_format = cf.get('log-format')
        assert self.log_format in ['netstr']
        self.log_parsing_errors = cf.getbool('log-parsing-errors', False)
        self.stats_period = cf.getint('stats-period', 30)

    def startup(self):
        """Init plugins, bind the non-blocking UDP socket, start callbacks."""
        super(PgLogForward, self).startup()

        # plugins should be ready before we start receiving udp stream
        fmt = self.log_format
        self.load_plugins(log_fmt=fmt)
        for plugin in self.plugins:
            plugin.init(fmt)

        addr = self.listen_addr = (self.listen_host, self.listen_port)
        sock = self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setblocking(0)
        try:
            sock.bind(addr)
        except Exception as e:
            self.log.exception("failed to bind to %s - %s", addr, e)
            raise

        loop = self.ioloop = IOLoop.instance()
        on_readable = functools.partial(self.handle_udp, sock)
        loop.add_handler(sock.fileno(), on_readable, loop.READ)

        self.timer_stats = PeriodicCallback(self.send_stats,
                                            self.stats_period * 1000, loop)
        self.timer_stats.start()
Esempio n. 25
0
class CCDaemon(CCJob):
    """Base class for CC daemons, with plugin discovery and loading."""

    log = skytools.getLogger('d:CCDaemon')

    def find_plugins(self, mod_name, probe_func=None):
        """Return the CCDaemonPlugin subclasses defined in module mod_name.

        The module is imported if needed. Only classes whose __module__ is
        that module are considered, so imported ones are skipped. If
        probe_func is given, a class is kept only when probe_func(cls) is
        true.
        """
        found = []
        __import__(mod_name)
        mod = sys.modules[mod_name]
        for attr_name in dir(mod):
            obj = getattr(mod, attr_name)
            # builtin `type` is what types.TypeType aliased on Python 2 and,
            # unlike TypeType, also exists on Python 3
            if (isinstance(obj, type)
                    and issubclass(obj, CCDaemonPlugin)
                    and obj.__module__ == mod.__name__):
                if not probe_func or probe_func(obj):
                    found.append(obj)
                else:
                    self.log.debug("plugin %s probing negative", attr_name)
        if not found:
            self.log.info("no suitable plugins found in %s", mod_name)
        return found

    def load_plugins(self, *args, **kwargs):
        """Instantiate plugins from the 'plugins' config list.

        Each alias names a config section whose 'module' option is searched
        with find_plugins(). Every class found is instantiated and kept in
        self.plugins only if its probe(*args, **kwargs) returns true.
        """
        self.plugins = []
        for palias in self.cf.getlist('plugins'):
            pcf = self.cf.clone(palias)
            mod = pcf.get('module')
            for cls in self.find_plugins(mod):
                pin = cls(palias, pcf, self)
                if pin.probe(*args, **kwargs):
                    self.plugins.append(pin)
                else:
                    self.log.debug("plugin %s probing negative",
                                   pin.__class__.__name__)
        if self.plugins:
            self.log.info("Loaded plugins: %s",
                          [p.__class__.__name__ for p in self.plugins])
        else:
            # warning(): warn() is a deprecated alias
            self.log.warning("No plugins loaded!")
Esempio n. 26
0
class CCHandler(object):
    """Base class for message handlers; subclasses implement handle_msg."""

    log = skytools.getLogger('h:CCHandler')

    def __init__(self, hname, hcf, ccscript):
        """Keep the handler name and config, plus objects shared by ccscript."""
        self.hname = hname
        self.cf = hcf
        # shared resources borrowed from the owning script
        for attr in ('xtx', 'zctx', 'ioloop'):
            setattr(self, attr, getattr(ccscript, attr))
        self.cclocal = ccscript.local
        self.stat_inc = ccscript.stat_inc

    def handle_msg(self, rmsg):
        """Process one message; must be overridden."""
        raise NotImplementedError

    def stop(self):
        """Shutdown hook; does nothing by default."""
Esempio n. 27
0
class TaskManager:
    """Manage tasks on a single CCReqStream connection."""

    log = skytools.getLogger('TaskManager')

    def __init__(self, ccrq):
        self.ccrq = ccrq
        self.ioloop = ccrq.ioloop

    def send_task_async(self, task, task_cbfunc):
        """Async task launch.

        Callback function will be called on replies.

        @param task: TaskSendMessage
        @param task_cbfunc: func with args (is_done, reply_msg)
        @return: the TaskInfo tracking this task
        """
        assert isinstance(task, TaskSendMessage)

        self.log.debug('(%r, %r)', task, task_cbfunc)
        ti = TaskInfo(task, task_cbfunc, self.ccrq)
        ti.send_task()
        return ti

    def send_task(self, task):
        """Sync task launch.

        Runs the ioloop until a final reply (or failure) arrives, then
        returns the TaskInfo holding the replies.
        """
        assert isinstance(task, TaskSendMessage)

        # one placeholder for one argument; the former '(%r, %r)' with a
        # single argument made logging report a formatting error instead
        self.log.debug('(%r)', task)

        def cb(done, rep):
            # stop the loop started below once the task is done
            if done:
                self.ioloop.stop()

        ti = self.send_task_async(task, cb)
        self.ioloop.start()
        return ti
Esempio n. 28
0
class Echo(CCHandler):
    """ Echo handler / sender / monitor

    Answers incoming echo.request messages and, for every URL listed in the
    ``ping-remotes`` option, sends a ping each ``ping_tick`` seconds and
    tracks the pong replies.
    """

    CC_ROLES = ['local', 'remote']

    log = skytools.getLogger('h:Echo')

    ping_tick = 1
    zmq_hwm = 1
    zmq_linger = 0

    def __init__(self, hname, hcf, ccscript):
        super(Echo, self).__init__(hname, hcf, ccscript)

        self.echoes = {}  # echo stats for monitored peers
        self.stream = {}  # connections to monitored peers

        for url in self.cf.getlist("ping-remotes", ""):
            sock = self._make_socket(url)
            self.stream[url] = CCStream(sock,
                                        ccscript.ioloop,
                                        qmaxsize=self.zmq_hwm)
            self.stream[url].on_recv(self.on_recv)
            self.echoes[url] = EchoState(url)
            self.log.debug("will ping %s", url)

        self.timer = PeriodicCallback(self.ping, self.ping_tick * 1000,
                                      self.ioloop)
        self.timer.start()

    def _make_socket(self, url):
        """ Create socket for pinging remote CC. """
        sock = self.zctx.socket(zmq.XREQ)
        try:
            sock.setsockopt(zmq.HWM, self.zmq_hwm)
        except AttributeError:
            # newer pyzmq has no zmq.HWM constant; set_hwm() replaces it
            sock.set_hwm(self.zmq_hwm)
        sock.setsockopt(zmq.LINGER, self.zmq_linger)
        sock.connect(url)
        return sock

    def on_recv(self, zmsg):
        """ Got reply from a remote CC, process it.

        Stream callback: errors are logged and the message dropped, but
        KeyboardInterrupt/SystemExit are no longer swallowed.
        """
        try:
            self.log.trace("%r", zmsg)
            cmsg = CCMessage(zmsg)
            req = cmsg.get_dest()
            if req == "echo.response":
                self.process_response(cmsg)
            else:
                self.log.warning("unknown msg: %s", req)
        except Exception:
            self.log.exception("crashed, dropping msg")

    def handle_msg(self, cmsg):
        """ Got a message, process it. """

        self.log.trace("%r", cmsg)
        req = cmsg.get_dest()

        if req == "echo.request":
            self.process_request(cmsg)
        else:
            self.log.warning("unknown msg: %s", req)

    def process_request(self, cmsg):
        """ Ping received, respond with pong. """

        msg = cmsg.get_payload(self.xtx)
        if not msg:
            return

        rep = EchoResponseMessage(orig_hostname=msg['hostname'],
                                  orig_target=msg['target'],
                                  orig_time=msg['time'])
        rcm = self.xtx.create_cmsg(rep)
        rcm.take_route(cmsg)
        rcm.send_to(self.cclocal)

    def process_response(self, cmsg):
        """ Pong received, evaluate it. """

        msg = cmsg.get_payload(self.xtx)
        if not msg:
            return

        url = msg.orig_target
        if url not in self.echoes:
            self.log.warning("unknown pong: %s", url)
            return
        echo = self.echoes[url]
        echo.update_pong(msg)

        rtt = echo.time_pong - msg.orig_time
        if msg.orig_time == echo.time_ping:
            # reply to the most recent ping
            self.log.trace("echo time: %f s (%s)", rtt, url)
        elif rtt <= 5 * self.ping_tick:
            self.log.debug("late pong: %f s (%s)", rtt, url)
        else:
            self.log.info("too late pong: %f s (%s)", rtt, url)

    def send_request(self, url):
        """ Send ping to remote CC. """
        msg = EchoRequestMessage(target=url)
        cmsg = self.xtx.create_cmsg(msg)
        self.stream[url].send_cmsg(cmsg)
        self.echoes[url].update_ping(msg)
        self.log.trace("%r", msg)

    def ping(self):
        """ Echo requesting and monitoring (periodic timer callback). """
        self.log.trace("")
        for url in self.stream:
            echo = self.echoes[url]
            if echo.time_ping - echo.time_pong > 5 * self.ping_tick:
                self.log.warning("no pong from %s for %f s", url,
                                 echo.time_ping - echo.time_pong)
            self.send_request(url)

    def stop(self):
        """Stop the periodic ping timer."""
        super(Echo, self).stop()
        self.log.info("stopping")
        self.timer.stop()
Esempio n. 29
0
class CryptoContext:
    """Load crypto config, check messages based on it.

    Depending on the configured names, outgoing messages are sent plain,
    signed, or signed+encrypted, and incoming messages are verified and/or
    decrypted accordingly.
    """

    log = skytools.getLogger('CryptoContext')

    def __init__(self, cf):
        """Read cms-* options from cf; a falsy cf disables all crypto."""
        if not cf:
            self.cms = None
            self.ks_dir = ''
            self.ks = KeyStore('', '')
            self.ca_name = None
            self.decrypt_name = None
            self.encrypt_name = None
            self.sign_name = None
            self.time_window = 0
            return
        self.ks_dir = cf.getfile('cms-keystore', '')
        priv_dir = os.path.join(self.ks_dir, 'private')
        # keep the keystore as an attribute in both branches
        self.ks = KeyStore(priv_dir, self.ks_dir)

        self.cms = CMSTool(self.ks)
        self.ca_name = cf.get('cms-verify-ca', '')
        self.decrypt_name = cf.get('cms-decrypt', '')
        self.sign_name = cf.get('cms-sign', '')
        self.encrypt_name = cf.get('cms-encrypt', '')
        self.time_window = int(cf.get('cms-time-window', '0'))

    def fill_config(self, cf_dict):
        """Copy non-empty crypto settings into cf_dict without overriding."""
        pairs = (('cms-verify-ca', 'ca_name'),
                 ('cms-decrypt', 'decrypt_name'),
                 ('cms-sign', 'sign_name'),
                 ('cms-encrypt', 'encrypt_name'),
                 ('cms-keystore', 'ks_dir'))
        for n1, n2 in pairs:
            v = getattr(self, n2)
            if v and n1 not in cf_dict:
                cf_dict[n1] = v

    def create_cmsg(self, msg, blob=None):
        """Serialize msg (plus optional blob) into a CCMessage.

        When signing, a SHA-1 hash of the blob is stored in the message so
        the receiver can check the blob.  Encryption without signing is
        rejected with an exception.
        """
        if blob is not None and self.sign_name:
            msg.blob_hash = "SHA-1:" + sha1(blob).hexdigest()
        js = msg.dump_json()
        part1 = js
        part2 = ''
        if self.encrypt_name and self.sign_name:
            self.log.trace("encrypt: %s", msg['req'])
            part1 = 'ENC1'
            part2 = self.cms.sign_and_encrypt(js, self.sign_name, self.encrypt_name)
        elif self.encrypt_name:
            raise Exception('encrypt_name without sign_name ?')
        elif self.sign_name:
            self.log.trace("sign: %s", msg['req'])
            part2 = self.cms.sign(js, self.sign_name)
        else:
            self.log.trace("no crypto: %s", msg['req'])
        zmsg = ['', msg.req.encode('utf8'), part1, part2]
        if blob is not None:
            zmsg.append(blob)
        return CCMessage(zmsg)

    def parse_cmsg(self, cmsg):
        """Decrypt/verify cmsg and return (msg, signer_info).

        Returns (None, None) after logging an error when the message does
        not satisfy the configured crypto policy.
        """
        req = cmsg.get_dest()
        part1 = cmsg.get_part1()
        part2 = cmsg.get_part2()
        blob = cmsg.get_part3()

        if self.decrypt_name:
            if part1 != 'ENC1':
                self.log.error('Expect encrypted message')
                return (None, None)
            # decrypt_name is known to be set here; only the CA can be missing
            if not self.ca_name:
                self.log.error('Cannot decrypt message')
                return (None, None)
            self.log.trace("decrypt: %s", req)
            js, sgn = self.cms.decrypt_and_verify(part2, self.decrypt_name, self.ca_name)
        elif part1 == 'ENC1':
            self.log.error('Got encrypted msg but cannot decrypt it')
            return (None, None)
        elif self.ca_name:
            if not part2:
                self.log.error('Expect signed message: %r', part1)
                return (None, None)
            self.log.trace("verify: %s", req)
            js, sgn = self.cms.verify(part1, part2, self.ca_name)
        else:
            self.log.trace("no crypto: %s", req)
            js, sgn = part1, None

        msg = Struct.from_json(js)
        # the routing destination must match the (possibly signed) payload
        if msg.req != req:
            self.log.error('hijacked message')
            return (None, None)

        if self.time_window:
            age = time.time() - msg.time
            if abs(age) > self.time_window:
                self.log.error('time diff bigger than %d s', self.time_window)
                return (None, None)

        if blob is not None:
            if not self.ca_name and not part2:
                # unsigned message: a hash could not be trusted anyway
                if getattr(msg, 'blob_hash', None):
                    self.log.debug('blob hash ignored')
            elif getattr(msg, 'blob_hash', None):
                ht, _sep, hv = msg.blob_hash.partition(':')
                if ht == 'SHA-1':
                    bh = sha1(blob).hexdigest()
                else:
                    self.log.error('unsupported hash type: %s', ht)
                    return (None, None)
                if bh != hv:
                    self.log.error('blob hash does not match: %s <> %s', bh, hv)
                    return (None, None)
            else:
                self.log.error('blob hash missing')
                return (None, None)
        elif msg.get('blob_hash', None):
            self.log.error('blob hash exists without blob')
            return (None, None)
        return msg, sgn
Esempio n. 30
0
File: server.py Progetto: postsql/cc
class CCServer(skytools.BaseScript):
    """Listens on single ZMQ sockets, dispatches messages to handlers.

    Config::
        ## Parameters for CCServer ##

        # listening socket for this CC instance
        cc-socket = tcp://127.0.0.1:22632

        # zmq customization:
        #zmq_nthreads = 1
        #zmq_linger = 500
        #zmq_hwm = 100

        #zmq_tcp_keepalive = 1
        #zmq_tcp_keepalive_intvl = 15
        #zmq_tcp_keepalive_idle = 240
        #zmq_tcp_keepalive_cnt = 4
    """
    extra_ini = """
    Extra segments::

        # map req prefix to handler segment
        [routes]
        log = h:locallog

        # segment for specific handler
        [h:locallog]
        handler = cc.handler.locallogger
    """

    log = skytools.getLogger('CCServer')

    cf_defaults = {
        'logfmt_console': LOG.fmt,
        'logfmt_file': LOG.fmt,
        'logfmt_console_verbose': LOG.fmt_v,
        'logfmt_file_verbose': LOG.fmt_v,
        'logdatefmt_console': LOG.datefmt,
        'logdatefmt_file': LOG.datefmt,
        'logdatefmt_console_verbose': LOG.datefmt_v,
        'logdatefmt_file_verbose': LOG.datefmt_v,
    }

    __version__ = __version__

    stat_level = 1

    zmq_nthreads = 1
    zmq_linger = 500
    zmq_hwm = 100
    zmq_rcvbuf = 0  # means no change
    zmq_sndbuf = 0  # means no change

    zmq_tcp_keepalive = 1
    zmq_tcp_keepalive_intvl = 15
    zmq_tcp_keepalive_idle = 4 * 60
    zmq_tcp_keepalive_cnt = 4

    def reload(self):
        """Re-read zmq tuning options from config (class attrs are defaults)."""
        super(CCServer, self).reload()

        self.zmq_nthreads = self.cf.getint('zmq_nthreads', self.zmq_nthreads)
        self.zmq_hwm = self.cf.getint('zmq_hwm', self.zmq_hwm)
        self.zmq_linger = self.cf.getint('zmq_linger', self.zmq_linger)
        self.zmq_rcvbuf = hsize_to_bytes(
            self.cf.get('zmq_rcvbuf', str(self.zmq_rcvbuf)))
        self.zmq_sndbuf = hsize_to_bytes(
            self.cf.get('zmq_sndbuf', str(self.zmq_sndbuf)))

        self.zmq_tcp_keepalive = self.cf.getint('zmq_tcp_keepalive',
                                                self.zmq_tcp_keepalive)
        self.zmq_tcp_keepalive_intvl = self.cf.getint(
            'zmq_tcp_keepalive_intvl', self.zmq_tcp_keepalive_intvl)
        self.zmq_tcp_keepalive_idle = self.cf.getint(
            'zmq_tcp_keepalive_idle', self.zmq_tcp_keepalive_idle)
        self.zmq_tcp_keepalive_cnt = self.cf.getint('zmq_tcp_keepalive_cnt',
                                                    self.zmq_tcp_keepalive_cnt)

    def print_ini(self):
        """Print sample config, including the routing/handler segments."""
        super(CCServer, self).print_ini()

        self._print_ini_frag(self.extra_ini)

    def startup(self):
        """Setup sockets and handlers."""

        super(CCServer, self).startup()

        self.log.info("C&C server version %s starting up..", self.__version__)

        self.xtx = CryptoContext(self.cf)
        self.zctx = zmq.Context(self.zmq_nthreads)
        self.ioloop = IOLoop.instance()

        self.local_url = self.cf.get('cc-socket')

        self.cur_role = self.cf.get('cc-role', 'insecure')
        if self.cur_role == 'insecure':
            self.log.warning(
                'CC is running in insecure mode, please add "cc-role = local" or "cc-role = remote" option to config'
            )

        self.stat_level = self.cf.getint('cc-stats', 1)
        if self.stat_level < 1:
            self.log.warning('CC statistics level too low: %d',
                             self.stat_level)

        # initialize local listen socket
        s = self.zctx.socket(zmq.XREP)
        s.setsockopt(zmq.LINGER, self.zmq_linger)
        try:
            s.setsockopt(zmq.HWM, self.zmq_hwm)
        except AttributeError:
            # newer pyzmq has no zmq.HWM constant; set_hwm() replaces it
            s.set_hwm(self.zmq_hwm)
        if self.zmq_rcvbuf > 0:
            s.setsockopt(zmq.RCVBUF, self.zmq_rcvbuf)
        if self.zmq_sndbuf > 0:
            s.setsockopt(zmq.SNDBUF, self.zmq_sndbuf)
        if self.zmq_tcp_keepalive > 0:
            if getattr(zmq, 'TCP_KEEPALIVE', -1) > 0:
                s.setsockopt(zmq.TCP_KEEPALIVE, self.zmq_tcp_keepalive)
                s.setsockopt(zmq.TCP_KEEPALIVE_INTVL,
                             self.zmq_tcp_keepalive_intvl)
                s.setsockopt(zmq.TCP_KEEPALIVE_IDLE,
                             self.zmq_tcp_keepalive_idle)
                s.setsockopt(zmq.TCP_KEEPALIVE_CNT, self.zmq_tcp_keepalive_cnt)
            else:
                self.log.info("TCP_KEEPALIVE not available")
        s.bind(self.local_url)
        self.local = CCStream(s, self.ioloop, qmaxsize=self.zmq_hwm)
        self.local.on_recv(self.handle_cc_recv)

        # build routing table from the [routes] section
        self.handlers = {}
        self.routes = {}
        rcf = skytools.Config('routes', self.cf.filename, ignore_defs=True)
        for r, hnames in rcf.cf.items('routes'):
            self.log.info('New route: %s = %s', r, hnames)
            for hname in [hn.strip() for hn in hnames.split(',')]:
                h = self.get_handler(hname)
                self.add_handler(r, h)

        self.stimer = PeriodicCallback(self.send_stats, 30 * 1000, self.ioloop)
        self.stimer.start()

    def send_stats(self):
        """Merge global stats into ours and emit them (skipped at level 0)."""
        if self.stat_level == 0:
            return

        # make sure we have something to send
        self.stat_increase('count', 0)

        # combine our stats with global stats
        self.combine_stats(reset_stats())

        super(CCServer, self).send_stats()

    def combine_stats(self, other):
        """Add every counter from the mapping `other` to our stats."""
        for k, v in other.items():
            self.stat_inc(k, v)

    def get_handler(self, hname):
        """Return the handler for config section hname, creating it once."""
        if hname in self.handlers:
            h = self.handlers[hname]
        else:
            hcf = self.cf.clone(hname)

            # renamed option: plugin->handler
            htype = hcf.get('plugin', '?')
            if htype == '?':
                htype = hcf.get('handler')

            cls = cc_handler_lookup(htype, self.cur_role)
            h = cls(hname, hcf, self)
            self.handlers[hname] = h
        return h

    def add_handler(self, rname, handler):
        """Add route to handler ('*' maps to the empty prefix)."""

        if rname == '*':
            r = ()
        else:
            r = tuple(rname.split('.'))
        self.log.debug('New route for handler: %r -> %s', r, handler.hname)
        rhandlers = self.routes.setdefault(r, [])
        rhandlers.append(handler)

    def handle_cc_recv(self, zmsg):
        """Got message from client, pick handler.

        Every handler whose route is a prefix of the dotted destination is
        called.  Stats are then recorded under ok/dropped/crashed.
        """

        start = time.time()
        self.stat_inc('count')
        self.log.trace('got msg: %r', zmsg)
        try:
            cmsg = CCMessage(zmsg)
        except Exception:
            self.log.exception('Invalid CC message')
            self.stat_increase('count.invalid')
            return

        # pre-set so the handler and stats code below still work when
        # get_dest()/get_size() themselves raise
        dst = None
        size = 0
        try:
            dst = cmsg.get_dest()
            size = cmsg.get_size()
            route = tuple(dst.split('.'))

            # find and run all handlers that match
            cnt = 0
            for n in range(0, 1 + len(route)):
                p = route[:n]
                for h in self.routes.get(p, []):
                    self.log.trace('calling handler %s', h.hname)
                    h.handle_msg(cmsg)
                    cnt += 1
            if cnt == 0:
                self.log.warning('dropping msg, no route: %s', dst)
                stat = 'dropped'
            else:
                stat = 'ok'

        except Exception:
            self.log.exception('crashed, dropping msg: %s', dst)
            stat = 'crashed'

        # update stats
        taken = time.time() - start
        self.stat_inc('bytes', size)
        self.stat_inc('seconds', taken)
        self.stat_inc('count.%s' % stat)
        self.stat_inc('bytes.%s' % stat, size)
        self.stat_inc('seconds.%s' % stat, taken)
        # per-destination stats need a known destination
        if self.stat_level > 1 and dst is not None:
            self.stat_inc('count.%s.msg.%s' % (stat, dst))
            self.stat_inc('bytes.%s.msg.%s' % (stat, dst), size)
            self.stat_inc('seconds.%s.msg.%s' % (stat, dst), taken)

    def work(self):
        """Default work loop simply runs ioloop."""
        self.set_single_loop(1)
        self.log.info('Starting IOLoop')
        try:
            self.ioloop.start()
        except zmq.ZMQError as d:
            # ZMQ gets surprised by EINTR
            if d.errno == errno.EINTR:
                return 1
            raise
Esempio n. 31
0
class CCReqStream:
    """Request/reply wrapper around a CC socket.

    Every outgoing query gets a unique id that is stored as the message
    route; replies carrying that id are dispatched to the query's callback.
    """

    log = skytools.getLogger('CCReqStream')

    zmq_hwm = 100
    zmq_linger = 500

    def __init__(self, cc_url, xtx, ioloop=None, zctx=None):
        """Connect to cc_url; missing zctx/ioloop fall back to singletons."""

        if not zctx:
            zctx = zmq.Context.instance()
        if not ioloop:
            ioloop = IOLoop.instance()

        sock = zctx.socket(zmq.XREQ)
        try:
            sock.setsockopt(zmq.HWM, self.zmq_hwm)
        except AttributeError:
            sock.set_hwm(self.zmq_hwm)
        sock.setsockopt(zmq.LINGER, self.zmq_linger)
        sock.connect(cc_url)

        self.ioloop = ioloop
        self.xtx = xtx
        self.ccs = CCStream(sock, ioloop, qmaxsize=self.zmq_hwm)

        self.query_id_seq = 1
        self.query_cache = {}

        # register only after the query cache exists
        self.ccs.on_recv(self.handle_recv)

    def remove_query(self, qid):
        """Forget a query so that later replies to it are ignored."""
        qinfo = self.query_cache.get(qid)
        if not qinfo:
            return
        del self.query_cache[qid]
        qinfo.set_timeout(None)

    def ccquery_sync(self, msg, timeout=0):
        """Blocking query: run the ioloop until the first reply arrives.

        Returns that reply (None if the loop stopped without one).
        """
        holder = {'reply': None}

        def on_reply(reply):
            holder['reply'] = reply
            self.ioloop.stop()
            return (False, 0)

        self.ccquery_async(msg, on_reply, timeout)
        self.ioloop.start()
        return holder['reply']

    def ccquery_async(self, msg, cbfunc, timeout=0):
        """Non-blocking query; replies are routed to cbfunc by query id.

        Returns the generated query id.
        """
        seq = self.query_id_seq
        qid = "Q%06d" % seq
        self.query_id_seq = seq + 1

        # the query id travels as the message route
        cmsg = self.xtx.create_cmsg(msg)
        cmsg.set_route([qid])

        qinfo = QueryInfo(qid, cmsg, cbfunc, self)
        self.query_cache[qid] = qinfo
        qinfo.set_timeout(timeout)
        qinfo.send_to(self.ccs)
        return qid

    def ccpublish(self, msg):
        """Send msg without expecting any reply."""
        self.xtx.create_cmsg(msg).send_to(self.ccs)

    def handle_recv(self, zmsg):
        """ZMQStream callback; never lets an exception escape."""
        try:
            self.handle_recv_real(zmsg)
        except Exception:
            self.log.exception('handle_recv_real crashed, dropping msg: %r',
                               zmsg)

    def handle_recv_real(self, zmsg):
        """Route one reply to its pending query (may raise)."""

        cmsg = CCMessage(zmsg)

        route = cmsg.get_route()
        if len(route) != 1:
            self.log.error('Invalid reply route: %r', route)
            return

        qid = route[0]
        if qid not in self.query_cache:
            self.log.error('reply for unknown query: %r', qid)
            return

        payload = cmsg.get_payload(self.xtx)
        self.query_cache[qid].launch_cb(payload)

    def resend(self, qid, timeout=0):
        """Send a pending query again and re-arm its timeout.

        Unknown (already removed) query ids are silently ignored.
        """
        if qid not in self.query_cache:
            return
        qinfo = self.query_cache[qid]
        qinfo.send_to(self.ccs)
        qinfo.set_timeout(timeout)
Esempio n. 32
0
class DBWorker(threading.Thread):
    """Worker thread, can do blocking calls.

    Receives CC messages on an XREP socket connected to ``url``, calls the
    requested database function (limited to ``func_list``; ``['*']``
    allows any) and sends the reply back over the same socket.
    """

    log = skytools.getLogger('h:DBWorker')

    def __init__(self, name, xtx, zctx, url, connstr, func_list):
        super(DBWorker, self).__init__(name=name)
        self.log = skytools.getLogger(name)
        self.xtx = xtx
        self.zctx = zctx
        self.master_url = url
        self.connstr = connstr
        self.func_list = func_list
        self.db = None
        self.master = None
        self.looping = True

    def startup(self):
        """Create socket and poller (called from run(), i.e. in this thread)."""
        self.master = self.zctx.socket(zmq.XREP)
        self.master.connect(self.master_url)
        self.poller = zmq.Poller()
        self.poller.register(self.master, zmq.POLLIN)

    def run(self):
        """Thread body: process messages until stop() is called."""
        self.log.info("%s running", self.name)
        self.startup()
        while self.looping:
            try:
                self.work()
            except Exception:
                self.log.exception('worker crash, dropping msg')
                # drop the db connection and back off before retrying
                self.reset()
                time.sleep(10)
        self.shutdown()

    def reset(self):
        """Drop the database connection; work() reconnects on demand."""
        db = self.db
        # forget the handle first so a failing close() cannot leave a broken
        # connection behind to be reused forever
        self.db = None
        if db:
            try:
                db.close()
            except Exception:
                pass  # best-effort close: the connection is discarded anyway

    def stop(self):
        """Ask the run() loop to exit after the current iteration."""
        self.looping = False

    def shutdown(self):
        """Release the database connection and the zmq socket."""
        self.log.info("%s stopping", self.name)
        self.reset()
        if self.master is not None:
            self.master.close()
            self.master = None

    def work(self):
        """Wait up to 1 s for a message, then process it."""
        socks = dict(self.poller.poll(1000))
        if self.master in socks and socks[self.master] == zmq.POLLIN:
            zmsg = self.master.recv_multipart()
        else:  # timeout
            return
        try:
            cmsg = CCMessage(zmsg)
            self.log.trace('%s', cmsg)
        except Exception:
            self.log.exception("invalid CC message")
            return

        if not self.db:
            self.log.info('connecting to database')
            self.db = skytools.connect_database(self.connstr)
            self.db.set_isolation_level(0)

        self.process_request(cmsg)

    def process_request(self, cmsg):
        """Run the requested DB function and send back the reply, if any."""
        msg = cmsg.get_payload(self.xtx)
        if not msg:
            return
        func = msg.function
        args = msg.get('params', [])
        if isinstance(args, StringType):
            args = cc.json.loads(args)
        # explicit check instead of assert (asserts vanish under -O)
        if not isinstance(args, (DictType, ListType, TupleType)):
            self.log.error('Invalid DB function params: %r', args)
            return

        if len(self.func_list) == 1 and self.func_list[0] == '*':
            pass
        elif func in self.func_list:
            pass
        else:
            self.log.error('Function call not allowed: %r', func)
            return None

        q = "select %s (%%s)" % (skytools.quote_fqident(func), )
        if isinstance(args, DictType):
            # argument names are spliced into SQL, so allow identifiers only
            if not all(re.match(r"^[a-zA-Z0-9_]+$", k) for k in args.keys()):
                self.log.error("Invalid DB function argument name in %r",
                               args.keys())
                return
            q %= (", ".join(["%s := %%(%s)s" % (k, k) for k in args.keys()]), )
        else:
            q %= (", ".join(["%s" for a in args]), )

        curs = self.db.cursor()
        if self.log.isEnabledFor(skytools.skylog.TRACE):
            self.log.trace('Executing: %s', curs.mogrify(q, args))
        else:
            self.log.debug('Executing: %s', q)
        curs.execute(q, args)

        rt = msg.get('return')
        if rt in (None, '', 'no'):
            return
        if rt == 'json':
            # first column of first row holds the JSON reply
            row = curs.fetchone()
            if row:
                jsr = row[0]
            else:
                jsr = '{}'
            rep = parse_json(jsr)
        else:
            if rt == 'all':
                rs = curs.fetchall()
            elif rt == 'one':
                rs = curs.fetchone()
            else:
                # previously fell through with `rs` unbound -> NameError
                self.log.error('Unsupported return type: %r', rt)
                return
            rep = ReplyMessage(req="reply.%s" % msg.req, data=rs)
            if curs.rowcount >= 0:
                rep.rowcount = curs.rowcount
            if curs.statusmessage:
                rep.statusmessage = curs.statusmessage
            if msg.get('ident'):
                rep.ident = msg.get('ident')

        rcm = self.xtx.create_cmsg(rep)
        rcm.take_route(cmsg)
        rcm.send_to(self.master)
Esempio n. 33
0
class LogfileTailer(CCDaemon):
    """ Logfile tailer for rotated log files """

    log = skytools.getLogger('d:LogfileTailer')

    BUF_MINBYTES = 64 * 1024
    PROBESLEFT = 2  # number of retries after old log EOF and new log spotted

    def reload(self):
        """Read tailer options from config and derive buffering limits."""
        super(LogfileTailer, self).reload()

        self.op_mode = self.cf.get('operation-mode', '')
        if self.op_mode not in (None, '', 'classic', 'rotated'):
            self.log.error("unknown operation-mode: %s", self.op_mode)

        self.file_mode = self.cf.get('file-mode', '')
        if self.file_mode not in (None, '', 'text', 'binary'):
            self.log.error("unknown file-mode: %s", self.file_mode)

        self.logdir = self.cf.getfile('logdir')
        if self.op_mode in (None, '', 'classic'):
            self.logmask = self.cf.get('logmask')
        elif self.op_mode == 'rotated':
            self.logname = self.cf.get('logname')
            # raw string: '\?' / '\*' are invalid escapes in plain literals
            if re.search(r'[?*]', self.logname):
                self.log.error("wildcards in logname not supported: %s",
                               self.logname)
            self.logmask = self.logname

        self.compression = self.cf.get('compression', '')
        if self.compression not in (None, '', 'none', 'gzip', 'bzip2'):
            self.log.error("unknown compression: %s", self.compression)
        self.compression_level = self.cf.getint('compression-level', '')
        self.msg_suffix = self.cf.get('msg-suffix', '')
        if self.msg_suffix and not is_msg_req_valid(self.msg_suffix):
            self.log.error("invalid msg-suffix: %s", self.msg_suffix)
            self.msg_suffix = None
        self.use_blob = self.cf.getbool('use-blob', True)
        self.lag_maxbytes = cc.util.hsize_to_bytes(
            self.cf.get('lag-max-bytes', '0'))

        self.reverse_sort = False
        self.buf_maxbytes = cc.util.hsize_to_bytes(
            self.cf.get('buffer-bytes', '0'))
        self.buf_maxlines = self.cf.getint('buffer-lines', -1)
        self.buf_maxdelay = 1.0

        # compensate for our config class weakness
        if self.buf_maxbytes <= 0: self.buf_maxbytes = None
        if self.buf_maxlines < 0: self.buf_maxlines = None
        # set defaults if nothing found in config
        if self.buf_maxbytes is None and self.buf_maxlines is None:
            self.buf_maxbytes = 1024 * 1024

        if self.compression not in (None, '', 'none'):
            # None means "no byte limit" (only buffer-lines set): nothing to
            # raise; comparing/formatting None here used to fail
            if (self.buf_maxbytes is not None
                    and self.buf_maxbytes < self.BUF_MINBYTES):
                self.log.info("buffer-bytes too low, adjusting: %i -> %i",
                              self.buf_maxbytes, self.BUF_MINBYTES)
                self.buf_maxbytes = self.BUF_MINBYTES

    def startup(self):
        """Initialize tailing state and restore the saved position if usable.

        The saved state is discarded in 'rotated' mode, when the lag cannot
        be determined, or when it exceeds lag_maxbytes.
        """
        super(LogfileTailer, self).startup()

        self.logfile = None  # full path
        self.logf = None  # file object
        self.logfpos = None  # tell()
        self.probesleft = self.PROBESLEFT
        self.first = True
        self.tailed_files = 0
        self.tailed_bytes = 0
        self.buffer = cStringIO.StringIO()
        self.buflines = 0
        self.bufseek = None
        self.saved_fpos = None
        self.save_file = None
        self.logf_dev = self.logf_ino = None

        sfn = self.get_save_filename()
        try:
            with open(sfn, "r") as f:
                # save file format: "<offset>\t<logfile path>"
                s = f.readline().split('\t', 1)
                try:
                    self.logfile = s[1].strip()
                    self.saved_fpos = int(s[0])
                    self.log.info("found saved state for %s", self.logfile)
                except (IndexError, ValueError):
                    # empty or malformed save file: start without saved state
                    self.logfile = self.saved_fpos = None

            if self.op_mode == 'rotated':
                self.log.info("cannot use saved state in this operation mode")
                self.logfile = self.saved_fpos = None

            lag = self.count_lag_bytes()
            if lag is not None:
                self.log.info("currently lagging %i bytes behind", lag)
                if lag > self.lag_maxbytes:
                    self.log.warning("lag too big, skipping")
                    self.logfile = self.saved_fpos = None
            else:
                self.log.warning("cannot determine lag, skipping")
                self.logfile = self.saved_fpos = None
        except IOError:
            pass  # e.g. no save file yet
        self.save_file = open(sfn, "a")

    def count_lag_bytes(self):
        """Return how many bytes remain unread since the saved position.

        Sums the sizes of the matching files from the last one (in sorted
        order) back to the saved log file, minus the saved offset.  Returns
        None when there is no usable saved state or the offset lies beyond
        the data on disk (e.g. the file was truncated).
        """
        files = self.get_all_filenames()
        if self.logfile not in files or self.saved_fpos is None:
            return None
        lag = 0
        for fn in reversed(files):
            lag += os.stat(fn).st_size
            if fn == self.logfile:
                break
        lag -= self.saved_fpos
        if lag < 0:
            # used to be an assert, which crashed startup instead of skipping
            self.log.warning("saved offset %i is beyond end of data for %s",
                             self.saved_fpos, self.logfile)
            return None
        return lag

    def get_all_filenames(self):
        """List files in logdir matching logmask, sorted by name."""
        pattern = os.path.join(self.logdir, self.logmask)
        return sorted(glob.iglob(pattern), reverse=self.reverse_sort)

    def get_last_filename(self):
        """Return the last file in sorted order, or None if none match."""
        names = self.get_all_filenames()
        return names[-1] if names else None

    def get_next_filename(self):
        """Pick the log file to tail next.

        - no matching files: None
        - current file not among them: the last file
        - first run: the current file itself
        - otherwise: the file after the current one, or the current one
          when it is already the last
        """
        names = self.get_all_filenames()
        if not names:
            return None
        if self.logfile not in names:
            return names[-1]
        pos = names.index(self.logfile)
        if self.first:
            return names[pos]
        if pos + 1 < len(names):
            return names[pos + 1]
        return names[pos]

    def get_save_filename(self):
        """Derive the position-save file name from the pidfile name."""
        base, _ext = os.path.splitext(self.pidfile)
        return base + ".save"

    def save_file_pos(self):
        """Persist "<bufseek>\\t<logfile>" as the only content of the save file."""
        self.save_file.truncate(0)
        self.save_file.write("%i\t%s" % (self.bufseek, self.logfile))
        # truncate() flushes pending writes before truncating, so without an
        # explicit flush each position reached disk only to be wiped by the
        # next save -- the file was effectively always empty
        self.save_file.flush()
        self.log.debug("saved offset %i for %s", self.bufseek, self.logfile)

    def is_new_file_available(self):
        """Tell whether tailing should move on to another file.

        classic mode: the next file name differs from the current one.
        rotated mode: the path now points to a different inode/device.
        """
        mode = self.op_mode
        if mode in (None, '', 'classic'):
            return self.get_next_filename() != self.logfile
        if mode == 'rotated':
            st = os.stat(self.logfile)
            same_file = (st.st_dev == self.logf_dev
                         and st.st_ino == self.logf_ino)
            return not same_file
        raise ValueError("unsupported mode of operation")

    def try_open_file(self, name):
        """ Try open log file; sleep a bit if unavailable. """
        if name:
            assert self.buffer.tell() == 0
            try:
                self.logf = open(name, 'rb')
                self.logfile = name
                self.logfpos = 0
                self.bufseek = 0
                self.send_stats()  # better do it async me think (?)
                self.log.info("Tailing %s", self.logfile)
                self.stat_inc('tailed_files')
                self.tailed_files += 1
                self.probesleft = self.PROBESLEFT
                st = os.fstat(self.logf.fileno())
                self.logf_dev, self.logf_ino = st.st_dev, st.st_ino
            except IOError, e:
                self.log.info("%s", e)
                time.sleep(0.2)
        else: