Example #1
0
 def start_telemetry_callback(self) -> None:
     """
     Start the periodic telemetry sender.

     Wraps ``self.send_telemetry`` in a PeriodicCallback (interval argument
     250 — presumably milliseconds, TODO confirm), keeps it in
     ``self.telemetry_callback`` and starts it.
     :return:
     """
     telemetry_timer = PeriodicCallback(self.send_telemetry, 250)
     self.telemetry_callback = telemetry_timer
     telemetry_timer.start()
Example #2
0
 async def create_session(self):
     """
     Create consul session

     Retries forever on ``ConsulRepeatableErrors``, sleeping
     ``DEFAULT_CONSUL_RETRY_TIMEOUT`` between attempts; every failed attempt
     is logged. Once the session exists, ``self.keep_alive`` is scheduled as
     a PeriodicCallback with interval ``DEFAULT_CONSUL_SESSION_TTL * 1000 / 2``
     (presumably TTL in seconds and interval in ms, i.e. half the TTL —
     TODO confirm).
     :return:
     """
     self.logger.info("Creating session")
     checks = ["serfHealth"]
     while True:
         try:
             self.session = await self.consul.session.create(
                 name=self.name,
                 checks=checks,
                 behavior="delete",
                 lock_delay=self.DEFAULT_CONSUL_LOCK_DELAY,
                 ttl=self.DEFAULT_CONSUL_SESSION_TTL,
             )
             break
         except ConsulRepeatableErrors as e:
             # Surface the reason; otherwise a service stuck here is silent
             self.logger.warning("Failed to create session, retrying: %s", e)
             await asyncio.sleep(self.DEFAULT_CONSUL_RETRY_TIMEOUT)
     self.logger.info("Session id: %s", self.session)
     self.keep_alive_task = PeriodicCallback(
         self.keep_alive, self.DEFAULT_CONSUL_SESSION_TTL * 1000 / 2
     )
     self.keep_alive_task.start()
Example #3
0
 async def start(self):
     """
     Start all pending tasks

     Records the calling thread's identifier in ``self.thread_id`` and starts
     a PeriodicCallback running ``self.expire_resolvers`` (interval argument
     10000 — presumably milliseconds), kept in
     ``self.resolver_expiration_task``.
     :return:
     """
     self.thread_id = threading.get_ident()
     expiration_timer = PeriodicCallback(self.expire_resolvers, 10000)
     self.resolver_expiration_task = expiration_timer
     expiration_timer.start()
Example #4
0
 async def on_activate(self):
     """
     Start periodic reporting and channel checks, then launch the
     ``subscribe_ch_streams`` and ``flush_data`` background tasks.
     """
     # Periodic jobs (interval arguments presumably in milliseconds)
     PeriodicCallback(self.report, 10000).start()
     PeriodicCallback(self.check_channels, config.chwriter.batch_delay_ms).start()
     self.logger.info("Sending records to %s", self.ch_address)
     # NOTE(review): asyncio recommends keeping references to created tasks;
     # none are retained here.
     for background in (self.subscribe_ch_streams(), self.flush_data()):
         asyncio.create_task(background)
Example #5
0
 async def on_activate(self):
     """
     Load rules from database after loading config

     Loads the ruleset, patterns, triggers, link actions and handlers,
     preloads the MIB cache, acquires a slot, subscribes ``self.on_event`` to
     that slot's partition of the pool's ``events`` stream and finally starts
     a periodic ``self.report`` callback.
     """
     lookup_handler = config.classifier.lookup_handler
     self.logger.info("Using rule lookup solution: %s", lookup_handler)
     # Classification rules and auxiliary handlers
     self.ruleset.load()
     self.patternset.load()
     self.load_triggers()
     self.load_link_action()
     self.load_handlers()
     # Heat up MIB cache
     MIBData.preload()
     slot_info = await self.acquire_slot()
     self.slot_number, self.total_slots = slot_info
     events_stream = "events.%s" % config.pool
     await self.subscribe_stream(events_stream, self.slot_number, self.on_event)
     PeriodicCallback(self.report, 1000).start()
Example #6
0
 async def on_activate(self):
     """
     Start syslog listeners, invalid-source reporting and mapping tracking.

     Every address/port yielded by ``server.iter_listen`` for
     ``config.syslogcollector.listen`` is bound; a failing bind (OSError) is
     counted in the ``socket_listen_error`` metric and logged but does not
     abort activation.
     """
     # Listen sockets
     server = SyslogServer(service=self)
     for addr, port in server.iter_listen(config.syslogcollector.listen):
         self.logger.info("Starting syslog server at %s:%s", addr, port)
         try:
             server.listen(port, addr)
         except OSError as e:
             metrics["error", ("type", "socket_listen_error")] += 1
             self.logger.error("Failed to start syslog server at %s:%s: %s",
                               addr, port, e)
     server.start()
     # Report invalid sources every 60 seconds
     self.logger.info("Starting invalid sources reporting task")
     self.report_invalid_callback = PeriodicCallback(
         self.report_invalid_sources, 60000)
     self.report_invalid_callback.start()
     # Start tracking changes
     asyncio.get_running_loop().create_task(self.get_object_mappings())
Example #7
0
class BaseService(object):
    """
    Basic service implementation.

    * on_change_<var> - subscribed to changes of config variable <var>
    """

    # Service name
    name = None
    # Leader lock name
    # Only one active instace per leader lock can be active
    # at given moment
    # Config variables can be expanded as %(varname)s
    leader_lock_name = None

    # Leader group name
    # Only one service in leader group can be running at a time
    # Config variables can be expanded as %(varname)s
    # @todo: Deprecated, must be removed
    leader_group_name = None
    # Pooled service are used to distribute load between services.
    # Pool name set in NOC_POOL parameter or --pool option.
    # May be used in conjunction with leader_group_name
    # to allow only one instance of services per node or datacenter
    pooled = False

    # Format string to set process name
    # config variables can be expanded as %(name)s
    process_name = "noc-%(name).10s"
    # Connect to MongoDB on activate
    use_mongo = False
    # Initialize gettext and process *language* configuration
    use_translation = False
    # Initialize jinja2 templating engine
    use_jinja = False
    # Collect and send spans
    use_telemetry = False
    # Register traefik backend if not None
    traefik_backend = None
    # Traefik frontend rule
    # i.e. PathPrefix:/api/<name>
    traefik_frontend_rule = None
    # Require DCS health status to be considered healthy
    # Usually means resolution error to required services
    # temporary leads service to unhealthy state
    require_dcs_health = True

    LOG_FORMAT = config.log_format

    LOG_LEVELS = {
        "critical": logging.CRITICAL,
        "error": logging.ERROR,
        "warning": logging.WARNING,
        "info": logging.INFO,
        "debug": logging.DEBUG,
    }

    DEFAULT_SHARDING_KEY = "managed_object"

    SHARDING_KEYS = {"span": "ctx"}

    # Timeout to wait NSQ writer is to close
    NSQ_WRITER_CLOSE_TRY_TIMEOUT = 0.25
    # Times to try to close NSQ writer
    NSQ_WRITER_CLOSE_RETRIES = 5

    class RegistrationError(Exception):
        pass

    def __init__(self):
        set_service(self)
        sys.excepthook = excepthook
        self.loop: Optional[asyncio.BaseEventLoop] = None
        self.logger = None
        self.service_id = str(uuid.uuid4())
        self.executors = {}
        self.start_time = perf_counter()
        self.pid = os.getpid()
        self.nsq_readers = {}  # handler -> Reader
        self.nsq_writer = None
        self.startup_ts = None
        self.telemetry_callback = None
        self.dcs = None
        # Effective address and port
        self.address = None
        self.port = None
        self.is_active = False
        # Can be initialized in subclasses
        self.scheduler = None
        # NSQ Topics
        # name -> TopicQueue()
        self.topic_queues: Dict[str, TopicQueue] = {}
        self.topic_queue_lock = threading.Lock()
        # Liftbridge publisher
        self.publish_queue: Optional[LiftBridgeQueue] = None
        self.publisher_start_lock = threading.Lock()
        #
        self.active_subscribers = 0
        self.subscriber_shutdown_waiter: Optional[asyncio.Event] = None
        # Metrics partitions
        self.n_metrics_partitions = len(config.clickhouse.cluster_topology.split(","))
        #
        self.metrics_key_lock = threading.Lock()
        self.metrics_key_seq: int = 0

    def create_parser(self) -> argparse.ArgumentParser:
        """
        Return argument parser
        """
        return argparse.ArgumentParser()

    def add_arguments(self, parser: argparse.ArgumentParser) -> None:
        """
        Apply additional parser arguments
        """
        parser.add_argument(
            "--node", action="store", dest="node", default=config.node, help="NOC node name"
        )
        parser.add_argument(
            "--loglevel",
            action="store",
            choices=list(self.LOG_LEVELS),
            dest="loglevel",
            default=config.loglevel,
            help="Logging level",
        )
        parser.add_argument(
            "--instance",
            action="store",
            dest="instance",
            type=int,
            default=config.instance,
            help="Instance number",
        )
        parser.add_argument(
            "--debug",
            action="store_true",
            dest="debug",
            default=False,
            help="Dump additional debugging info",
        )
        parser.add_argument(
            "--dcs",
            action="store",
            dest="dcs",
            default=DEFAULT_DCS,
            help="Distributed Coordinated Storage URL",
        )
        if self.pooled:
            parser.add_argument(
                "--pool", action="store", dest="pool", default=config.pool, help="NOC pool name"
            )

    @classmethod
    def die(cls, msg: str = "") -> NoReturn:
        """
        Dump message to stdout and terminate process with error code
        """
        sys.stdout.write(str(msg) + "\n")
        sys.stdout.flush()
        os._exit(1)

    def setup_logging(self, loglevel=None):
        """
        Create new or setup existing logger
        """
        # @todo: Duplicates config.setup_logging
        if not loglevel:
            loglevel = config.loglevel
        logger = logging.getLogger()
        if len(logger.handlers):
            # Logger is already initialized
            fmt = ErrorFormatter(self.LOG_FORMAT, None)
            for h in logging.root.handlers:
                if isinstance(h, logging.StreamHandler):
                    h.stream = sys.stdout
                h.setFormatter(fmt)
            logging.root.setLevel(loglevel)
        else:
            # Initialize logger
            logging.basicConfig(stream=sys.stdout, format=self.LOG_FORMAT, level=loglevel)
        self.logger = logging.getLogger(self.name)
        logging.captureWarnings(True)

    def setup_test_logging(self):
        self.logger = logging.getLogger(self.name)

    def setup_translation(self):
        from noc.core.translation import set_translation, ugettext

        set_translation(self.name, config.language)
        if self.use_jinja:
            from jinja2.defaults import DEFAULT_NAMESPACE

            if "_" not in DEFAULT_NAMESPACE:
                DEFAULT_NAMESPACE["_"] = ugettext

    def on_change_loglevel(self, old_value, new_value):
        if new_value not in self.LOG_LEVELS:
            self.logger.error("Invalid loglevel '%s'. Ignoring", new_value)
            return
        self.logger.warning("Changing loglevel to %s", new_value)
        logging.getLogger().setLevel(self.LOG_LEVELS[new_value])

    def log_separator(self, symbol="*", length=72):
        """
        Log a separator string to visually split log
        """
        self.logger.warning(symbol * length)
        if config.features.forensic:
            self.logger.warning("[noc.core.forensic] [=Process restarted]")

    def setup_signal_handlers(self):
        """
        Set up signal handlers
        """
        signal.signal(signal.SIGTERM, self.on_SIGTERM)
        signal.signal(signal.SIGHUP, self.on_SIGHUP)

    def set_proc_title(self):
        """
        Set process title
        """
        v = {"name": self.name, "instance": config.instance or "", "pool": config.pool or ""}
        title = self.process_name % v
        self.logger.debug("Setting process title to: %s", title)
        setproctitle.setproctitle(title)

    def start(self):
        """
        Run main server loop
        """
        self.startup_ts = perf_counter()
        parser = self.create_parser()
        self.add_arguments(parser)
        options = parser.parse_args(sys.argv[1:])
        cmd_options = vars(options)
        cmd_options.pop("args", ())
        # Bootstrap logging with --loglevel
        self.setup_logging(cmd_options["loglevel"])
        self.log_separator()
        # Setup timezone
        try:
            self.logger.info("Setting timezone to %s", config.timezone)
            setup_timezone()
        except ValueError as e:
            self.die(str(e))
        # Setup title
        self.set_proc_title()
        # Setup signal handlers
        self.setup_signal_handlers()
        self.on_start()
        # Starting IOLoop
        self.is_active = True
        if self.pooled:
            self.logger.warning("Running service %s (pool: %s)", self.name, config.pool)
        else:
            self.logger.warning("Running service %s", self.name)
        try:
            setup_asyncio()
            self.loop = asyncio.get_event_loop()
            # Initialize DCS
            self.dcs = get_dcs(cmd_options["dcs"])
            # Activate service
            self.loop.create_task(self.activate())
            self.logger.warning("Starting IOLoop")
            self.loop.run_forever()
        except KeyboardInterrupt:
            self.logger.warning("Interrupted by Ctrl+C")
        except self.RegistrationError:
            self.logger.info("Registration failed")
        except Exception:
            error_report()
        finally:
            if self.loop:
                self.loop.create_task(self.deactivate())
        self.logger.warning("Service %s has been terminated", self.name)

    def get_event_loop(self) -> asyncio.BaseEventLoop:
        return self.loop

    def on_start(self):
        """
        Reload config
        """
        if self.use_translation:
            self.setup_translation()

    def stop(self):
        self.logger.warning("Stopping")
        self.loop.create_task(self.deactivate())

    def on_SIGHUP(self, signo, frame):
        # self.logger.warning("SIGHUP caught, rereading config")
        pass

    def on_SIGTERM(self, signo, frame):
        self.logger.warning("SIGTERM caught, Stopping")
        self.stop()

    def get_service_address(self) -> Tuple[str, int]:
        """
        Returns an (address, port) for HTTP service listener
        """
        if self.address and self.port:
            return self.address, self.port
        if config.listen:
            addr, port = config.listen.split(":")
            port_tracker = config.instance
        else:
            addr, port = "auto", 0
            port_tracker = 0
        if addr == "auto":
            addr = config.node
            self.logger.info("Autodetecting address: auto -> %s", addr)
        addr = config.node
        port = int(port) + port_tracker
        return addr, port

    async def init_api(self):
        """
        Initialize API routers and handlers
        :return:
        """
        raise NotImplementedError

    async def shutdown_api(self):
        """
        Stop API services
        :return:
        """
        raise NotImplementedError

    async def activate(self):
        """
        Initialize services before run
        """
        self.logger.warning("Activating service")
        if self.use_mongo:
            from noc.core.mongo.connection import connect

            connect()

        await self.init_api()
        #
        if self.use_telemetry:
            self.start_telemetry_callback()
        self.loop.create_task(self.on_register())

    async def deactivate(self):
        if not self.is_active:
            return
        self.is_active = False
        self.logger.info("Deactivating")
        # Shutdown API
        await self.shutdown_api()
        # Release registration
        if self.dcs:
            self.logger.info("Deregistration")
            await self.dcs.deregister()
        # Shutdown schedulers
        if self.scheduler:
            try:
                self.logger.info("Shutting down scheduler")
                await self.scheduler.shutdown()
            except asyncio.TimeoutError:
                self.logger.info("Timed out when shutting down scheduler")
        # Shutdown subscriptions
        await self.shutdown_subscriptions()
        # Shutdown executors
        await self.shutdown_executors()
        # Custom deactivation
        await self.on_deactivate()
        # Shutdown NSQ topics
        await self.shutdown_topic_queues()
        # Shutdown Liftbridge publisher
        await self.shutdown_publisher()
        # Continue deactivation
        # Finally stop ioloop
        self.dcs = None
        self.logger.info("Stopping EventLoop")
        self.loop.stop()
        m = {}
        apply_metrics(m)
        apply_hists(m)
        apply_quantiles(m)
        self.logger.info("Post-mortem metrics: %s", m)
        self.die("")

    def get_register_tags(self):
        tags = ["noc"]
        if config.features.traefik:
            if self.traefik_backend and self.traefik_frontend_rule:
                tags += [
                    "traefik.tags=backend",
                    "traefik.backend=%s" % self.traefik_backend,
                    "traefik.frontend.rule=%s" % self.traefik_frontend_rule,
                    "traefik.backend.load-balancing=wrr",
                ]
                weight = self.get_backend_weight()
                if weight:
                    tags += ["traefik.backend.weight=%s" % weight]
                limit = self.get_backend_limit()
                if limit:
                    tags += ["traefik.backend.maxconn.amount=%s" % limit]
        return tags

    async def on_register(self):
        addr, port = self.get_service_address()
        r = await self.dcs.register(
            self.name,
            addr,
            port,
            pool=config.pool if self.pooled else None,
            lock=self.get_leader_lock_name(),
            tags=self.get_register_tags(),
        )
        if r:
            # Finally call on_activate
            await self.on_activate()
            self.logger.info("Service is active (in %.2fms)", self.uptime() * 1000)
        else:
            raise self.RegistrationError()

    async def on_activate(self):
        """
        Called when service activated
        """
        return

    async def acquire_lock(self):
        await self.dcs.acquire_lock("lock-%s" % self.name)

    async def acquire_slot(self):
        if self.pooled:
            name = "%s-%s" % (self.name, config.pool)
        else:
            name = self.name
        slot_number, total_slots = await self.dcs.acquire_slot(name, config.global_n_instances)
        if total_slots <= 0:
            self.die("Service misconfiguration detected: Invalid total_slots")
        return slot_number, total_slots

    async def on_deactivate(self):
        return

    def open_rpc(self, name, pool=None, sync=False, hints=None):
        """
        Returns RPC proxy object.
        """
        if pool:
            svc = "%s-%s" % (name, pool)
        else:
            svc = name
        return RPCProxy(self, svc, sync=sync, hints=hints)

    def get_mon_status(self):
        return True

    def get_mon_data(self):
        """
        Returns monitoring data
        """
        r = {
            "status": self.get_mon_status(),
            "service": self.name,
            "instance": str(self.service_id),
            "node": config.node,
            "pid": self.pid,
            # Current process uptime
            "uptime": perf_counter() - self.start_time,
        }
        if self.pooled:
            r["pool"] = config.pool
        if self.executors:
            for x in self.executors:
                self.executors[x].apply_metrics(r)
        apply_metrics(r)
        for topic in self.topic_queues:
            self.topic_queues[topic].apply_metrics(r)
        if self.publish_queue:
            self.publish_queue.apply_metrics(r)
        apply_hists(r)
        apply_quantiles(r)
        return r

    def iter_rpc_retry_timeout(self):
        """
        Yield timeout to wait after unsuccessful RPC connection
        """
        for t in config.rpc.retry_timeout.split(","):
            yield float(t)

    async def subscribe_stream(
        self,
        stream: str,
        partition: int,
        handler: Callable[
            [Message],
            Awaitable[None],
        ],
        start_timestamp: Optional[float] = None,
        start_position: StartPosition = StartPosition.RESUME,
        cursor_id: Optional[str] = None,
        auto_set_cursor: bool = True,
    ) -> None:
        # @todo: Restart on failure
        self.logger.info("Subscribing %s:%s", stream, partition)
        cursor_id = cursor_id or self.name
        try:
            async with LiftBridgeClient() as client:
                self.active_subscribers += 1
                async for msg in client.subscribe(
                    stream=stream,
                    partition=partition,
                    start_position=start_position,
                    cursor_id=self.name,
                    start_timestamp=start_timestamp,
                ):
                    try:
                        await handler(msg)
                    except Exception as e:
                        self.logger.error("Failed to process message: %s", e)
                    if auto_set_cursor and cursor_id:
                        await client.set_cursor(
                            stream=stream,
                            partition=partition,
                            cursor_id=cursor_id,
                            offset=msg.offset,
                        )
                    if self.subscriber_shutdown_waiter:
                        break
        finally:
            self.active_subscribers -= 1
        if self.subscriber_shutdown_waiter and not self.active_subscribers:
            self.subscriber_shutdown_waiter.set()

    async def subscribe(self, topic, channel, handler, raw=False, **kwargs):
        """
        Subscribe message to channel
        """

        def call_json_handler(message):
            metrics[metric_in] += 1
            try:
                data = orjson.loads(message.body)
            except ValueError as e:
                metrics[metric_decode_fail] += 1
                self.logger.debug("Cannot decode JSON message: %s", e)
                return True  # Broken message
            if isinstance(data, dict):
                with ErrorReport():
                    r = handler(message, **data)
            else:
                with ErrorReport():
                    r = handler(message, data)
            if r:
                metrics[metric_processed] += 1
            elif message.is_async():
                message.on("finish", on_finish)
            else:
                metrics[metric_deferred] += 1
            return r

        def call_raw_handler(message):
            metrics[metric_in] += 1
            with ErrorReport():
                r = handler(message, message.body)
            if r:
                metrics[metric_processed] += 1
            elif message.is_async():
                message.on("finish", on_finish)
            else:
                metrics[metric_deferred] += 1
            return r

        def on_finish(*args, **kwargs):
            metrics[metric_processed] += 1

        t = topic.replace(".", "_")
        metric_in = "nsq_msg_in_%s" % t
        metric_decode_fail = "nsq_msg_decode_fail_%s" % t
        metric_processed = "nsq_msg_processed_%s" % t
        metric_deferred = "nsq_msg_deferred_%s" % t
        lookupd = [str(a) for a in config.nsqlookupd.http_addresses]
        self.logger.info("Subscribing to %s/%s (lookupd: %s)", topic, channel, ", ".join(lookupd))
        self.nsq_readers[handler] = NSQReader(
            message_handler=call_raw_handler if raw else call_json_handler,
            topic=topic,
            channel=channel,
            lookupd_http_addresses=lookupd,
            snappy=config.nsqd.compression == "snappy",
            deflate=config.nsqd.compression == "deflate",
            deflate_level=config.nsqd.compression_level
            if config.nsqd.compression == "deflate"
            else 6,
            **kwargs,
        )

    def suspend_subscription(self, handler):
        """
        Suspend subscription for handler
        :param handler:
        :return:
        """
        self.logger.info("Suspending subscription for handler %s", handler)
        self.nsq_readers[handler].set_max_in_flight(0)

    def resume_subscription(self, handler):
        """
        Resume subscription for handler
        :param handler:
        :return:
        """
        self.logger.info("Resuming subscription for handler %s", handler)
        self.nsq_readers[handler].set_max_in_flight(config.nsqd.max_in_flight)

    def _init_publisher(self):
        """
        Spin-up publisher and queue
        :return:
        """
        with self.publisher_start_lock:
            if self.publish_queue:
                return  # Created in concurrent thread
            self.publish_queue = LiftBridgeQueue(self.loop)
            self.loop.create_task(self.publisher())

    def publish(
        self,
        value: bytes,
        stream: str,
        partition: Optional[int] = None,
        key: Optional[bytes] = None,
        headers: Optional[Dict[str, bytes]] = None,
    ):
        """
        Schedule publish request
        :param value:
        :param stream:
        :param partition:
        :param key:
        :param headers:
        :return:
        """
        if not self.publish_queue:
            self._init_publisher()
        req = LiftBridgeClient.get_publish_request(
            value=value,
            stream=stream,
            partition=partition,
            key=key,
            headers=headers,
            auto_compress=bool(config.liftbridge.compression_method),
        )
        self.publish_queue.put(req)

    async def publisher_guard(self):
        while not self.publish_queue.to_shutdown:
            try:
                await self.publisher()
            except Exception as e:
                self.logger.error("Unhandled exception in liftbridge publisher: %s", e)

    async def publisher(self):
        async with LiftBridgeClient() as client:
            while not self.publish_queue.to_shutdown:
                req = await self.publish_queue.get(timeout=1)
                if not req:
                    continue  # Timeout or shutdown
                try:
                    await client.publish_sync(req, wait_for_stream=True)
                except LiftbridgeError as e:
                    self.logger.error("Failed to publish message: %s", e)
                    self.logger.error("Retry message")
                    await asyncio.sleep(1)
                    self.publish_queue.put(req, fifo=False)

    def get_topic_queue(self, topic: str) -> TopicQueue:
        q = self.topic_queues.get(topic)
        if q:
            return q
        # Create when necessary
        with self.topic_queue_lock:
            q = self.topic_queues.get(topic)
            if q:
                return q  # Created in concurrent task
            q = TopicQueue(topic)
            self.topic_queues[topic] = q
            self.loop.create_task(self.nsq_publisher_guard(q))
            return q

    async def nsq_publisher_guard(self, queue: TopicQueue):
        while not queue.to_shutdown:
            try:
                await self.nsq_publisher(queue)
            except Exception as e:
                self.logger.error("Unhandled exception in NSQ publisher, restarting: %s", e)

    async def nsq_publisher(self, queue: TopicQueue):
        """
        Publisher for NSQ topic

        :return:
        """
        topic = queue.topic
        self.logger.info("[nsq|%s] Starting NSQ publisher", topic)
        while not queue.to_shutdown or not queue.is_empty():
            # Message throttling. Wait and allow to collect more messages
            await queue.wait(timeout=10, rate=config.nsqd.topic_mpub_rate)
            # Get next batch up to `mpub_messages` messages or up to `mpub_size` size
            messages = list(
                queue.iter_get(
                    n=config.nsqd.mpub_messages,
                    size=config.nsqd.mpub_size,
                    total_overhead=4,
                    message_overhead=4,
                )
            )
            if not messages:
                continue
            try:
                self.logger.debug("[nsq|%s] Publishing %d messages", topic, len(messages))
                await mpub(topic, messages, dcs=self.dcs)
            except NSQPubError:
                if queue.to_shutdown:
                    self.logger.debug(
                        "[nsq|%s] Failed to publish during shutdown. Dropping messages", topic
                    )
                else:
                    # Return to queue
                    self.logger.info(
                        "[nsq|%s] Failed to publish. %d messages returned to queue",
                        topic,
                        len(messages),
                    )
                    queue.return_messages(messages)
            del messages  # Release memory
        self.logger.info("[nsq|%s] Stopping NSQ publisher", topic)
        # Queue is shut down and empty, notify
        queue.notify_shutdown()

    async def shutdown_executors(self):
        if self.executors:
            self.logger.info("Shutting down executors")
            for x in self.executors:
                try:
                    self.logger.info("Shutting down %s", x)
                    await self.executors[x].shutdown()
                except asyncio.TimeoutError:
                    self.logger.info("Timed out when shutting down %s", x)

    async def shutdown_subscriptions(self):
        self.logger.info("Shutting down subscriptions")
        self.subscriber_shutdown_waiter = asyncio.Event()
        try:
            await asyncio.wait_for(self.subscriber_shutdown_waiter.wait(), 10)
        except asyncio.TimeoutError:
            self.logger.info(
                "Timed out when shutting down subscriptions. Some message may be still processing"
            )

    async def shutdown_publisher(self):
        if self.publish_queue:
            r = await self.publish_queue.drain(5.0)
            if not r:
                self.logger.info(
                    "Unclean shutdown of liftbridge queue. Up to %d messages may be lost",
                    self.publish_queue.qsize(),
                )
            self.publish_queue.shutdown()

    async def shutdown_topic_queues(self):
        # Issue shutdown
        with self.topic_queue_lock:
            has_topics = bool(self.topic_queues)
            if has_topics:
                self.logger.info("Shutting down topic queues")
            for topic in self.topic_queues:
                self.topic_queues[topic].shutdown()
        # Wait for shutdown
        while has_topics:
            with self.topic_queue_lock:
                topic = next(iter(self.topic_queues.keys()))
                queue = self.topic_queues[topic]
                del self.topic_queues[topic]
                has_topics = bool(self.topic_queues)
            try:
                self.logger.info("Waiting shutdown of topic queue %s", topic)
                await queue.wait_for_shutdown(5.0)
            except asyncio.TimeoutError:
                self.logger.info("Failed to shutdown topic queue %s: Timed out", topic)

    def pub(self, topic, data, raw=False):
        """
        Publish message to topic
        :param topic: Topic name
        :param data: Message to send. Message will be automatically
          converted to JSON if *raw* is False, or passed as-is
          otherwise
        :param raw: True - pass message as-is, False - convert to JSON
        """
        q = self.get_topic_queue(topic)
        if raw:
            q.put(data)
        else:
            for chunk in q.iter_encode_chunks(data):
                q.put(chunk)

    def mpub(self, topic, messages):
        """
        Publish multiple messages to topic
        """
        q = self.get_topic_queue(topic)
        for m in messages:
            for chunk in q.iter_encode_chunks(m):
                q.put(chunk)

    def get_executor(self, name: str) -> ThreadPoolExecutor:
        """
        Return or start named executor
        """
        executor = self.executors.get(name)
        if not executor:
            xt = "%s.%s_threads" % (self.name, name)
            max_threads = config.get_parameter(xt)
            self.logger.info(
                "Starting threadpool executor %s (up to %d threads)", name, max_threads
            )
            executor = ThreadPoolExecutor(max_threads, name=name)
            self.executors[name] = executor
        return executor

    def run_in_executor(
        self, name: str, fn: Callable[[Any], T], *args: Any, **kwargs: Any
    ) -> asyncio.Future:
        """
        Submit ``fn(*args, **kwargs)`` to the named executor.

        The executor is looked up (and started on first use) through
        ``get_executor()``; the value returned is whatever that executor's
        ``submit()`` returns.
        NOTE(review): annotated as ``asyncio.Future`` -- confirm against the
        executor implementation's ``submit()`` return type.

        :param name: Executor name
        :param fn: Callable to run in the pool
        :return: Future of the submitted call
        """
        return self.get_executor(name).submit(fn, *args, **kwargs)

    @staticmethod
    def _iter_metrics_raw_chunks(metrics: List[Dict[str, Any]]) -> Iterable[bytes]:
        """
        Pack metric records into newline-separated JSON chunks.

        Each record is serialized with ``orjson.dumps`` and records are joined
        with a newline byte. The current chunk is flushed before adding a
        record would make it reach ``config.liftbridge.max_message_size``.
        A single record that is itself too large is still emitted alone, as
        records are never split.

        :param metrics: List of metric records
        :return: Iterator of non-empty byte chunks
        """
        max_size = config.liftbridge.max_message_size
        chunk: List[bytes] = []
        size = 0
        for record in metrics:
            jm = orjson.dumps(record)
            ljm = len(jm)
            # Flush only when there is something to flush: an oversized first
            # record must not produce an empty chunk
            if chunk and size + ljm + 1 >= max_size:
                yield b"\n".join(chunk)
                chunk = []
                size = 0
            chunk.append(jm)
            # Account for the separator before every record but the first
            size += ljm + 1 if size else ljm
        if chunk:
            yield b"\n".join(chunk)

    def register_metrics(
        self, table: str, metrics: List[Dict[str, Any]], key: Optional[int] = None
    ):
        """
        Send collected metrics to `table`.

        Records are packed into chunks by ``_iter_metrics_raw_chunks()`` and
        every chunk is published to the ``ch.<table>`` stream, partition
        ``key % n_metrics_partitions``.

        :param table: Table name
        :param metrics: List of dicts containing metrics records
        :param key: Sharding key, None for round-robin distribution
          (taken from an internal, lock-protected sequence counter)
        :return:
        """
        if key is None:
            with self.metrics_key_lock:
                key = self.metrics_key_seq
                self.metrics_key_seq = key + 1
        stream_name = f"ch.{table}"
        for chunk in self._iter_metrics_raw_chunks(metrics):
            self.publish(
                chunk,
                stream=stream_name,
                partition=key % self.n_metrics_partitions,
            )

    def start_telemetry_callback(self) -> None:
        """
        Start the periodic telemetry callback.

        Schedules ``send_telemetry`` via ``PeriodicCallback`` with a 250 ms
        period and keeps the callback in ``self.telemetry_callback``.
        :return:
        """
        callback = PeriodicCallback(self.send_telemetry, 250)
        self.telemetry_callback = callback
        callback.start()

    async def send_telemetry(self):
        """
        Publish collected tracing spans.

        Fetches pending spans with ``get_spans()`` and, when there are any,
        registers them as metrics for the ``span`` table.

        :return:
        """
        spans = get_spans()
        if not spans:
            return
        self.register_metrics("span", [span_to_dict(span) for span in spans])

    def get_leader_lock_name(self):
        """
        Return the leader lock name with ``%(pool)s`` expanded, or None
        when the service has no leader lock configured.
        """
        template = self.leader_lock_name
        if not template:
            return None
        return template % {"pool": config.pool}

    def get_backend_weight(self):
        """
        Backend weight for weighted load balancers (e.g. traefik).

        The base implementation has no preference; subclasses may override.
        :return: None, meaning "use the default weight"
        """
        return None

    def get_backend_limit(self):
        """
        Backend connection limit for load balancers (e.g. traefik).

        The base implementation imposes no limit; subclasses may override.
        :return: None, meaning "no limit"
        """
        return None

    def is_valid_health_check(self, service_id):
        """
        Check that a health check was addressed to this service instance.

        Without a DCS, or while the DCS has no health-check service id yet,
        every id is accepted.

        :param service_id: Service id received with the health check
        :return: True if the id is acceptable, False otherwise
        """
        dcs = self.dcs
        if not dcs or not dcs.health_check_service_id:
            return True
        return dcs.health_check_service_id == service_id

    def get_health_status(self):
        """
        Health status to be reported to the service registry.

        Delegates to the DCS when it is initialized and the service requires
        DCS health; otherwise the service is always healthy.
        :return: (http code, message)
        """
        dcs = self.dcs
        if not (dcs and self.require_dcs_health):
            return 200, "OK"
        return dcs.get_status()

    def uptime(self):
        """
        Seconds elapsed since ``startup_ts`` (a ``perf_counter()`` reading),
        or 0 when the startup timestamp is not set.
        """
        started = self.startup_ts
        return perf_counter() - started if started else 0

    async def get_stream_partitions(self, stream: str) -> int:
        """
        Return the number of partitions of a Liftbridge stream.

        Polls cluster metadata once a second until *stream* is listed with
        at least one partition, so the result is always positive.

        :param stream: Stream name
        :return: Number of stream partitions
        """
        async with LiftBridgeClient() as client:
            while True:
                meta = await client.fetch_metadata()
                for stream_meta in meta.metadata or ():
                    if stream_meta.name == stream:
                        if stream_meta.partitions:
                            return len(stream_meta.partitions)
                        break
                # Stream is missing or has no partitions yet:
                # cluster election in progress or cluster is misconfigured
                self.logger.info("Stream '%s' has no active partitions. Waiting", stream)
                await asyncio.sleep(1)
Exemple #8
0
class ConsulDCS(DCSBase):
    """
    Consul-based DCS

    Provides service registration, session-based distributed locks and
    shard slot allocation on top of the Consul agent, KV and health APIs.

    URL format:
    consul://<address>[:<port>]/<kv root>?token=<token>&check_interval=<...>&check_timeout=<...>&release_after=<...>
    """

    DEFAULT_CONSUL_HOST = config.consul.host
    DEFAULT_CONSUL_PORT = config.consul.port
    DEFAULT_CONSUL_CHECK_INTERVAL = config.consul.check_interval
    DEFAULT_CONSUL_CHECK_TIMEOUT = config.consul.connect_timeout
    # Duration string: configured value with an "s" suffix appended
    DEFAULT_CONSUL_RELEASE = "".join([str(config.consul.release), "s"])
    DEFAULT_CONSUL_SESSION_TTL = config.consul.session_ttl
    DEFAULT_CONSUL_LOCK_DELAY = config.consul.lock_delay
    DEFAULT_CONSUL_RETRY_TIMEOUT = config.consul.retry_timeout
    DEFAULT_CONSUL_KEEPALIVE_ATTEMPTS = config.consul.keepalive_attempts
    # Slot manifest placeholder for a holder whose session has disappeared
    EMPTY_HOLDER = ""

    resolver_cls = ConsulResolver

    def __init__(self, runner, url):
        # Defaults; parse_url() (invoked by DCSBase.__init__) may override
        # host, port, prefix, token, check interval/timeout and release_after
        self.name = None
        self.consul_host = self.DEFAULT_CONSUL_HOST
        self.consul_port = self.DEFAULT_CONSUL_PORT
        self.consul_prefix = "/"
        self.consul_token = config.consul.token
        self.check_interval = self.DEFAULT_CONSUL_CHECK_INTERVAL
        self.check_timeout = self.DEFAULT_CONSUL_CHECK_TIMEOUT
        self.release_after = self.DEFAULT_CONSUL_RELEASE
        self.keepalive_attempts = self.DEFAULT_CONSUL_KEEPALIVE_ATTEMPTS
        self.svc_name = None
        self.svc_address = None
        self.svc_port = None
        self.svc_check_url = None
        self.svc_id = None
        self.session = None
        self.slot_number = None
        self.total_slots = None
        super().__init__(runner, url)
        # Client is created after parse_url() so URL overrides are honoured
        self.consul = ConsulClient(host=self.consul_host,
                                   port=self.consul_port,
                                   token=self.consul_token)
        self.session = None
        self.keep_alive_task = None
        self.service_watchers = {}
        # Guards against overlapping keep_alive() runs
        self.in_keep_alive = False

    def parse_url(self, u):
        """
        Apply settings from a parsed DCS URL.

        ``host[:port]`` comes from the netloc, the KV prefix from the path
        (leading "/" and one trailing "/" removed). Recognized query keys are
        token, check_interval, check_timeout (integers) and release_after;
        other keys are ignored.

        :param u: Result of ``urlparse()``
        """
        if ":" in u.netloc:
            self.consul_host, port = u.netloc.rsplit(":", 1)
            self.consul_port = int(port)
        else:
            self.consul_host = u.netloc
        self.consul_prefix = u.path[1:]
        if self.consul_prefix.endswith("/"):
            self.consul_prefix = self.consul_prefix[:-1]
        for q in u.query.split("&"):
            if not q or "=" not in q:
                continue
            k, v = q.split("=", 1)
            v = unquote(v)
            if k == "token":
                self.consul_token = v
            elif k == "check_interval":
                self.check_interval = int(v)
            elif k == "check_timeout":
                self.check_timeout = int(v)
            elif k == "release_after":
                self.release_after = v

    async def create_session(self):
        """
        Create consul session and start its keep-alive task.

        Retries on ConsulRepeatableErrors until the session is created.
        The session is bound to the ``serfHealth`` check and its keys are
        deleted when it is invalidated (``behavior="delete"``). Keep-alive
        runs every half of the session TTL (assuming the TTL is configured
        in seconds, as it is multiplied by 1000 for the callback period).
        :return:
        """
        self.logger.info("Creating session")
        checks = ["serfHealth"]
        while True:
            try:
                self.session = await self.consul.session.create(
                    name=self.name,
                    checks=checks,
                    behavior="delete",
                    lock_delay=self.DEFAULT_CONSUL_LOCK_DELAY,
                    ttl=self.DEFAULT_CONSUL_SESSION_TTL,
                )
                break
            except ConsulRepeatableErrors:
                await asyncio.sleep(self.DEFAULT_CONSUL_RETRY_TIMEOUT)
                continue
        self.logger.info("Session id: %s", self.session)
        self.keep_alive_task = PeriodicCallback(
            self.keep_alive, self.DEFAULT_CONSUL_SESSION_TTL * 1000 / 2)
        self.keep_alive_task.start()

    async def destroy_session(self):
        """
        Stop keep-alive and destroy the current session, if any.

        ConsulRepeatableErrors from the destroy call are ignored; the local
        session reference is cleared either way.
        """
        if self.session:
            self.logger.info("Destroying session %s", self.session)
            if self.keep_alive_task:
                self.keep_alive_task.stop()
                self.keep_alive_task = None
            try:
                await self.consul.session.destroy(self.session)
            except ConsulRepeatableErrors:
                pass  # Ignore consul errors
            self.session = None

    async def register(self,
                       name,
                       address,
                       port,
                       pool=None,
                       lock=None,
                       tags=None):
        """
        Register service in Consul.

        :param name: Service name; ``-<pool>`` is appended when *pool* is set
        :param address: Service address
        :param port: Service port
        :param pool: Optional pool name
        :param lock: Optional lock name to acquire before registering
        :param tags: List of extra tags (the service id is always appended)
        :return: Result of the agent registration call, or True when
          service registration is disabled by config
        """
        if pool:
            name = "%s-%s" % (name, pool)
        self.name = name
        if lock:
            await self.acquire_lock(lock)
        # Reuse session id as service id when a lock has created a session
        svc_id = self.session or str("svc-%s" % uuid.uuid4())
        tags = tags[:] if tags else []
        tags += [svc_id]
        self.svc_check_url = "http://%s:%s/health/?service=%s" % (address,
                                                                  port, svc_id)
        self.health_check_service_id = svc_id
        if config.features.consul_healthchecks:
            checks = consul.Check.http(self.svc_check_url, self.check_interval,
                                       "%ds" % self.check_timeout)
            checks["DeregisterCriticalServiceAfter"] = self.release_after
        else:
            checks = []
        if config.features.service_registration:
            while True:
                self.logger.info("Registering service %s: %s:%s (id=%s)", name,
                                 address, port, svc_id)
                try:
                    r = await self.consul.agent.service.register(
                        name=name,
                        service_id=svc_id,
                        address=address,
                        port=port,
                        tags=tags,
                        check=checks,
                    )
                except ConsulRepeatableErrors as e:
                    metrics["error", ("type", "cant_register_consul")] += 1
                    self.logger.info("Cannot register service %s: %s", name, e)
                    await asyncio.sleep(self.DEFAULT_CONSUL_RETRY_TIMEOUT)
                    continue
                if r:
                    self.svc_id = svc_id
                break
            return r
        else:
            return True

    async def deregister(self):
        """
        Destroy the session and deregister the service.

        Errors are counted in metrics and logged, never raised.
        """
        if self.session:
            try:
                await self.destroy_session()
            except ConsulRepeatableErrors:
                metrics["error",
                        ("type", "cant_destroy_consul_session_soft")] += 1
            except Exception as e:
                metrics["error", ("type", "cant_destroy_consul_session")] += 1
                self.logger.error("Cannot destroy session: %s", e)
        if self.svc_id and config.features.service_registration:
            try:
                await self.consul.agent.service.deregister(self.svc_id)
            except ConsulRepeatableErrors:
                metrics["error", ("type", "cant_deregister_consul_soft")] += 1
            except Exception as e:
                metrics["error", ("type", "cant_deregister_consul")] += 1
                self.logger.error("Cannot deregister service: %s", e)
            self.svc_id = None

    async def keep_alive(self):
        """
        Periodic callback renewing the consul session.

        Makes up to ``keepalive_attempts`` renew attempts, sleeping between
        retries on ConsulRepeatableErrors. If the session is reported as not
        found or every attempt fails, the keep-alive task is stopped and the
        process is killed (see ``DCSBase.kill``). Overlapping runs are skipped
        and counted in metrics.
        """
        metrics["dcs_consul_keepalives"] += 1
        if self.in_keep_alive:
            metrics["error", ("type", "dcs_consul_overlapped_keepalives")] += 1
            return
        self.in_keep_alive = True
        try:
            if self.session:
                touched = False
                for _ in range(self.keepalive_attempts):
                    try:
                        await self.consul.session.renew(self.session)
                        self.logger.debug("Session renewed")
                        touched = True
                        break
                    except consul.base.NotFound as e:
                        self.logger.warning(
                            "Session lost by: '%s'. Forcing quit", e)
                        break
                    except ConsulRepeatableErrors as e:
                        self.logger.warning(
                            "Cannot refresh session due to ignorable error: %s",
                            e)
                        metrics["error",
                                ("type", "dcs_consul_keepalive_retries")] += 1
                        await asyncio.sleep(self.DEFAULT_CONSUL_RETRY_TIMEOUT)
                if not touched:
                    self.logger.critical("Cannot refresh session, stopping")
                    if self.keep_alive_task:
                        self.keep_alive_task.stop()
                        self.keep_alive_task = None
                    self.kill()
            elif self.keep_alive_task:
                # No session: nothing to keep alive
                self.keep_alive_task.stop()
                self.keep_alive_task = None
        finally:
            self.in_keep_alive = False

    def get_lock_path(self, lock):
        """Return KV key for the named lock: ``<prefix>/locks/<lock>``."""
        return self.consul_prefix + "/locks/" + lock

    async def acquire_lock(self, name):
        """
        Acquire the named lock, waiting as long as necessary.

        Creates a session when there is none, then repeatedly tries a KV
        ``put`` with ``acquire=<session>``. While the lock is held by
        someone else, uses a blocking KV ``get`` (by index) to wait for
        changes; when the key is deleted, sleeps a randomized lock delay
        before retrying.

        :param name: Lock name
        """
        if not self.session:
            await self.create_session()
        key = self.get_lock_path(name)
        index = None
        while True:
            self.logger.info("Acquiring lock: %s", key)
            try:
                status = await self.consul.kv.put(key=key,
                                                  value=self.session,
                                                  acquire=self.session,
                                                  token=self.consul_token)
                if status:
                    break
                else:
                    metrics["error",
                            ("type", "dcs_consul_failed_get_lock")] += 1
                    self.logger.info("Failed to acquire lock")
                    await asyncio.sleep(self.DEFAULT_CONSUL_RETRY_TIMEOUT)
            except ConsulRepeatableErrors:
                await asyncio.sleep(self.DEFAULT_CONSUL_RETRY_TIMEOUT)
                continue
            # Waiting for lock release
            while True:
                try:
                    index, data = await self.consul.kv.get(
                        key=key, index=index, token=self.consul_token)
                    if not data:
                        index = None  # Key has been deleted
                        await asyncio.sleep(self.DEFAULT_CONSUL_LOCK_DELAY *
                                            (0.5 + random.random()))
                    break
                except ConsulRepeatableErrors:
                    await asyncio.sleep(self.DEFAULT_CONSUL_RETRY_TIMEOUT)
        self.logger.info("Lock acquired")

    async def get_slot_limit(self, name) -> Optional[int]:
        """
        Return the current limit for given slot

        Reads the ``Limit`` field of ``<prefix>/slots/<name>/manifest``;
        returns 0 when the manifest is missing or has no limit. Retries on
        ConsulRepeatableErrors.

        :param name: <service name>-<pool>
        :return: Slot limit
        """
        manifest_path = "%s/slots/%s/manifest" % (self.consul_prefix, name)
        while True:
            self.logger.info("Attempting to get slot")
            # index=0: always a non-blocking read
            try:
                _, cv = await self.consul.kv.get(key=manifest_path, index=0)
                if not cv:
                    return 0
                return int(orjson.loads(cv["Value"]).get("Limit", 0))
            except ConsulRepeatableErrors:
                await asyncio.sleep(self.DEFAULT_CONSUL_RETRY_TIMEOUT)
                continue

    async def acquire_slot(self, name, limit):
        """
        Acquire shard slot

        Registers this session as a contender under ``<prefix>/slots/<name>``,
        then updates the slot manifest with check-and-set. Free slots are
        taken from the end of the holders list first, then slots whose
        holder session is gone are reclaimed. When all slots are busy,
        waits (blocking KV query) for changes. The result is cached.

        :param name: <service name>-<pool>
        :param limit: Configured limit, used when no manifest exists yet
        :return: (slot number, number of instances)
        """
        if not self.session:
            await self.create_session()
        if self.total_slots is not None:
            return self.slot_number, self.total_slots
        prefix = "%s/slots/%s" % (self.consul_prefix, name)
        contender_path = "%s/%s" % (prefix, self.session)
        contender_info = self.session
        manifest_path = "%s/manifest" % prefix
        self.logger.info("Writing contender slot info into %s", contender_path)
        while True:
            try:
                status = await self.consul.kv.put(
                    key=contender_path,
                    value=contender_info,
                    acquire=self.session,
                    token=self.consul_token,
                )
                if status:
                    break
                else:
                    metrics["error",
                            ("type", "dcs_consul_failed_get_slot")] += 1
                    self.logger.info("Failed to write contender slot info")
                    await asyncio.sleep(self.DEFAULT_CONSUL_RETRY_TIMEOUT)
            except ConsulRepeatableErrors:
                await asyncio.sleep(self.DEFAULT_CONSUL_RETRY_TIMEOUT)
        index = 0
        cas = 0
        while True:
            self.logger.info("Attempting to get slot")
            seen_sessions = set()
            dead_contenders = set()
            manifest = None
            # Non-blocking for a first time
            # Block until change every next try
            try:
                index, cv = await self.consul.kv.get(key=prefix,
                                                     index=index,
                                                     recurse=True)
            except ConsulRepeatableErrors:
                await asyncio.sleep(self.DEFAULT_CONSUL_RETRY_TIMEOUT)
                continue
            for e in cv:
                if e["Key"] == manifest_path:
                    cas = e["ModifyIndex"]
                    # @todo: Handle errors
                    manifest = orjson.loads(e["Value"])
                else:
                    if "Session" in e:
                        seen_sessions.add(e["Session"])
                    else:
                        dead_contenders.add(e["Key"])
            if manifest:
                total_slots = int(manifest.get("Limit", 0))
                # Holders without a live contender key become empty slots
                holders = [
                    h if h in seen_sessions else self.EMPTY_HOLDER
                    for h in manifest.get("Holders", [])
                ]
            else:
                self.logger.info("Initializing manifest")
                total_slots = limit
                holders = []
            # Try to allocate slot
            if len(holders) < total_slots:
                # Available slots from the end
                slot_number = len(holders)
                holders += [self.session]
            else:
                # Try to reclaim slots in the middle
                try:
                    slot_number = holders.index(self.EMPTY_HOLDER)
                    holders[slot_number] = self.session
                except ValueError:
                    self.logger.info("All slots a busy, waiting")
                    continue
            # Update manifest
            self.logger.info("Attempting to acquire slot %s/%s", slot_number,
                             total_slots)
            try:
                r = await self.consul.kv.put(
                    key=manifest_path,
                    value=smart_text(
                        orjson.dumps({
                            "Limit": total_slots,
                            "Holders": holders
                        },
                                     option=orjson.OPT_INDENT_2)),
                    cas=cas,
                )
            except ConsulRepeatableErrors as e:
                self.logger.info("Cannot acquire slot: %s", e)
                continue
            if r:
                self.logger.info("Acquired slot %s/%s", slot_number,
                                 total_slots)
                self.slot_number = slot_number
                self.total_slots = total_slots
                return slot_number, total_slots
            self.logger.info("Cannot acquire slot: CAS changed, retry")

    async def resolve_near(self,
                           name,
                           hint=None,
                           wait=True,
                           timeout=None,
                           full_result=False,
                           critical=False):
        """
        Resolve nearby service (coroutine)
        Commonly used for external services like databases

        Queries passing instances of the service sorted by proximity to the
        local agent. On Consul errors, or when no instance is available and
        *wait* is set, sleeps CONSUL_NEAR_RETRY_TIMEOUT without blocking the
        event loop and retries.

        :param name: Service name
        :param hint: Unused here
        :param wait: Wait until at least one instance is available
        :param timeout: Unused here
        :param full_result: Return all instances instead of the nearest one
        :param critical: Set/clear faulty DCS status on failure/success
        :return: List of "address:port" strings
        """
        self.logger.info("Resolve near service %s", name)
        index = 0
        while True:
            try:
                index, services = await self.consul.health.service(
                    service=name,
                    index=index,
                    near="_agent",
                    token=self.consul_token,
                    passing=True)
            except ConsulRepeatableErrors as e:
                metrics["error",
                        ("type", "dcs_consul_failed_resolve_near")] += 1
                self.logger.info("Consul error: %s", e)
                if critical:
                    metrics["error",
                            ("type",
                             "dcs_consul_failed_resolve_critical_near")] += 1
                    self.set_faulty_status("Consul error: %s" % e)
                # Must not block the event loop with time.sleep()
                await asyncio.sleep(CONSUL_NEAR_RETRY_TIMEOUT)
                continue
            if not services and wait:
                metrics["error",
                        ("type",
                         "dcs_consul_no_active_service %s" % name)] += 1
                self.logger.info("No active service %s. Waiting", name)
                if critical:
                    metrics["error",
                            ("type",
                             "dcs_consul_no_active_critical_service %s" %
                             name)] += 1
                    self.set_faulty_status("No active service %s. Waiting" %
                                           name)
                await asyncio.sleep(CONSUL_NEAR_RETRY_TIMEOUT)
                continue
            r = []
            for svc in services:
                r += [
                    "%s:%s" % (
                        str(svc["Service"]["Address"]
                            or svc["Node"]["Address"]),
                        str(svc["Service"]["Port"]),
                    )
                ]
                if not full_result:
                    break
            self.logger.info("Resolved near service %s to %s", name, r)
            if critical:
                self.clear_faulty_status()
            return r
Exemple #9
0
class DCSBase:
    """
    Base class for distributed coordination service (DCS) clients.

    Keeps a cache of service resolvers keyed by (name, critical, near),
    expires them periodically and tracks a health status that is reported
    through the /health endpoint. Subclasses implement registration, slot
    management and resolving.
    """

    # Resolver class, set by subclasses
    resolver_cls = None
    # HTTP code returned by the /health endpoint for a healthy service
    HEALTH_OK_HTTP_CODE = 200
    # HTTP code returned by the /health endpoint for an unhealthy service
    # that must be temporarily removed from resolvers
    HEALTH_FAILED_HTTP_CODE = 429

    def __init__(self, runner, url):
        self.runner = runner
        self.logger = logging.getLogger(__name__)
        self.url = url
        self.parse_url(urlparse(url))
        # (service name, critical, near) -> resolver instance
        self.resolvers = {}
        self.resolvers_lock = threading.Lock()
        self.resolver_expiration_task = None
        self.health_check_service_id = None
        self.status = True
        self.status_message = ""
        self.thread_id = None

    def parse_url(self, u):
        """Hook for subclasses: apply settings from the parsed URL."""
        pass

    async def start(self):
        """
        Start background tasks: remember the current thread and run
        resolver expiration every 10 seconds.
        :return:
        """
        self.thread_id = threading.get_ident()
        expiration = PeriodicCallback(self.expire_resolvers, 10000)
        self.resolver_expiration_task = expiration
        expiration.start()

    def stop(self):
        """
        Stop the expiration task and every tracked resolver, then forget
        all resolvers.
        :return:
        """
        task = self.resolver_expiration_task
        if task:
            task.stop()
            self.resolver_expiration_task = None
        with self.resolvers_lock:
            for svc, resolver in self.resolvers.items():
                self.logger.info("Stopping resolver for service %s", svc)
                resolver.stop()
            self.resolvers = {}

    async def register(self, name, address, port, pool=None, lock=None, tags=None):
        """
        Register service (must be implemented by subclasses)
        :param name: Service name
        :param address: Service address
        :param port: Service port
        :param pool: Optional pool name
        :param lock: Optional lock name
        :param tags: List of extra tags
        :return:
        """
        raise NotImplementedError()

    def kill(self):
        """Terminate the current process by sending SIGTERM to itself."""
        self.logger.info("Shooting self with SIGTERM")
        os.kill(os.getpid(), signal.SIGTERM)

    async def get_slot_limit(self, name) -> Optional[int]:
        """
        Return the current limit for given slot (implemented by subclasses)
        :param name: Slot name
        :return:
        """
        raise NotImplementedError()

    async def acquire_slot(self, name, limit):
        """
        Acquire shard slot (implemented by subclasses)
        :param name: <service name>-<pool>
        :param limit: Configured limit
        :return: (slot number, number of instances)
        """
        raise NotImplementedError()

    async def get_resolver(self, name, critical=False, near=False, track=True):
        """
        Return a resolver for the service, starting it if it is new.

        Tracked resolvers are cached per (name, critical, near); with
        ``track=False`` a fresh one-time resolver is created every call.
        Resolver start is scheduled on the running event loop.
        """

        def spawn(resolver):
            loop = asyncio.get_running_loop()
            loop.call_soon(loop.create_task, resolver.start())

        if not track:
            # One-time resolver, not cached
            resolver = self.resolver_cls(self, name, critical=critical, near=near, track=False)
            spawn(resolver)
            return resolver
        cache_key = (name, critical, near)
        with self.resolvers_lock:
            resolver = self.resolvers.get(cache_key)
            if not resolver:
                self.logger.info("Running resolver for service %s", name)
                resolver = self.resolver_cls(self, name, critical=critical, near=near)
                self.resolvers[cache_key] = resolver
                spawn(resolver)
        return resolver

    async def resolve(
        self,
        name,
        hint=None,
        wait=True,
        timeout=None,
        full_result=False,
        critical=False,
        near=False,
        track=True,
    ):
        """
        Resolve service through its resolver, run via ``runner.trampoline``.
        """

        async def wrap():
            resolver = await self.get_resolver(name, critical=critical, near=near, track=track)
            return await resolver.resolve(
                hint=hint, wait=wait, timeout=timeout, full_result=full_result
            )

        return await self.runner.trampoline(wrap())

    async def expire_resolvers(self):
        """Stop and drop every tracked resolver reporting itself expired."""
        with self.resolvers_lock:
            for svc, resolver in list(self.resolvers.items()):
                if not resolver.is_expired():
                    continue
                self.logger.info("Stopping expired resolver for service %s", svc)
                resolver.stop()
                del self.resolvers[svc]

    def resolve_sync(self, name, hint=None, wait=True, timeout=None, full_result=False):
        """
        Blocking wrapper around ``resolve()`` executed via ``run_sync``.
        :param name: Service name
        :param hint: Preferred instance
        :param wait: Wait for an active instance
        :param timeout: Resolve timeout
        :param full_result: Return all instances
        :return:
        """

        async def _resolve():
            return await self.resolve(
                name, hint=hint, wait=wait, timeout=timeout, full_result=full_result
            )

        return run_sync(_resolve)

    async def resolve_near(
        self, name, hint=None, wait=True, timeout=None, full_result=False, critical=False
    ):
        """
        Resolve nearby service (coroutine, implemented by subclasses)
        Commonly used for external services like databases
        :param name: Service name
        :return: address:port
        """
        raise NotImplementedError()

    def get_status(self):
        """Return (http code, message) describing the current health."""
        return (
            (self.HEALTH_OK_HTTP_CODE, "OK")
            if self.status
            else (self.HEALTH_FAILED_HTTP_CODE, self.status_message)
        )

    def set_faulty_status(self, message):
        """Mark service unhealthy with *message*; logs only on change."""
        if not self.status and self.status_message == message:
            return
        self.logger.info("Set faulty status to: %s", message)
        self.status = False
        self.status_message = message

    def clear_faulty_status(self):
        """Mark service healthy again; no-op when already healthy."""
        if self.status:
            return
        self.logger.info("Clearing faulty status")
        self.status = True
        self.status_message = ""
Exemple #10
0
class SyslogCollectorService(TornadoService):
    """
    Syslog collector: receives syslog messages, maps source addresses to
    managed object configs and forwards messages to the event stream
    and/or the ``syslog`` archive table.
    """

    name = "syslogcollector"
    pooled = True
    process_name = "noc-%(name).10s-%(pool).5s"

    def __init__(self):
        super().__init__()
        self.mappings_callback = None
        self.report_invalid_callback = None
        self.source_configs = {}  # id -> SourceConfig
        self.address_configs = {}  # address -> SourceConfig
        self.invalid_sources = defaultdict(int)  # ip -> count
        self.pool_partitions: Dict[str, int] = {}

    async def on_activate(self):
        """
        Start syslog listeners, the invalid source reporter and the
        object mappings tracking task.
        """
        # Listen sockets
        server = SyslogServer(service=self)
        for addr, port in server.iter_listen(config.syslogcollector.listen):
            self.logger.info("Starting syslog server at %s:%s", addr, port)
            try:
                server.listen(port, addr)
            except OSError as e:
                metrics["error", ("type", "socket_listen_error")] += 1
                self.logger.error("Failed to start syslog server at %s:%s: %s",
                                  addr, port, e)
        server.start()
        # Report invalid sources every 60 seconds
        self.logger.info("Stating invalid sources reporting task")
        self.report_invalid_callback = PeriodicCallback(
            self.report_invalid_sources, 60000)
        self.report_invalid_callback.start()
        # Start tracking changes
        asyncio.get_running_loop().create_task(self.get_object_mappings())

    async def get_pool_partitions(self, pool: str) -> int:
        """Return (and cache) the partition count of ``events.<pool>``."""
        parts = self.pool_partitions.get(pool)
        if not parts:
            parts = await self.get_stream_partitions("events.%s" % pool)
            self.pool_partitions[pool] = parts
        return parts

    def lookup_config(self, address: str) -> Optional[SourceConfig]:
        """
        Returns source config for given address or None when
        unknown source.

        Unknown addresses are counted as invalid sources once any mapping
        has been loaded.
        """
        cfg = self.address_configs.get(address)
        if cfg:
            return cfg
        # Register invalid event source
        if self.address_configs:
            self.invalid_sources[address] += 1
        metrics["error", ("type", "object_not_found")] += 1
        return None

    def register_message(self, cfg: SourceConfig, timestamp: int, message: str,
                         facility: int, severity: int) -> None:
        """
        Spool message to be sent

        Publishes the message to the source's event stream when event
        processing is enabled, and registers it in the ``syslog`` table when
        archiving is enabled and the source has a BI id.
        """
        if cfg.process_events:
            # Send to classifier
            metrics["events_out"] += 1
            self.publish(
                orjson.dumps({
                    "ts": timestamp,
                    "object": cfg.id,
                    "data": {
                        "source": "syslog",
                        "collector": config.pool,
                        "message": message
                    },
                }),
                stream=cfg.stream,
                partition=cfg.partition,
            )
        if cfg.archive_events and cfg.bi_id:
            # Archive message
            metrics["events_archived"] += 1
            now = datetime.datetime.now()
            ts = now.strftime("%Y-%m-%d %H:%M:%S")
            date = ts.split(" ")[0]
            self.register_metrics(
                "syslog",
                [{
                    "date": date,
                    "ts": ts,
                    "managed_object": cfg.bi_id,
                    "facility": facility,
                    "severity": severity,
                    "message": message,
                }],
            )

    async def get_object_mappings(self):
        """
        Subscribe and track datastream changes
        """
        # Register RPC aliases
        client = SysologDataStreamClient("cfgsyslog", service=self)
        # Track stream changes
        while True:
            self.logger.info("Starting to track object mappings")
            try:
                await client.query(
                    limit=config.syslogcollector.ds_limit,
                    filters=["pool(%s)" % config.pool],
                    block=1,
                )
            except NOCError as e:
                self.logger.info("Failed to get object mappings: %s", e)
                await asyncio.sleep(1)

    async def report_invalid_sources(self):
        """
        Report invalid event sources and reset the counters
        """
        if not self.invalid_sources:
            return
        total = sum(self.invalid_sources.values())
        self.logger.info(
            "Dropping %d messages with invalid sources: %s",
            total,
            ", ".join("%s: %s" % (s, count)
                      for s, count in self.invalid_sources.items()),
        )
        self.invalid_sources = defaultdict(int)

    async def update_source(self, data):
        """
        Create or update source config from a datastream record and
        update the address mappings accordingly.
        """
        # Get old config
        old_cfg = self.source_configs.get(data["id"])
        if old_cfg:
            old_addresses = set(old_cfg.addresses)
        else:
            old_addresses = set()
        # Get pool and sharding information
        fm_pool = data.get("fm_pool", None) or config.pool
        num_partitions = await self.get_pool_partitions(fm_pool)
        # Build new config
        cfg = SourceConfig(
            id=data["id"],
            addresses=tuple(data["addresses"]),
            bi_id=data.get("bi_id"),  # For backward compatibility
            process_events=data.get("process_events",
                                    True),  # For backward compatibility
            archive_events=data.get("archive_events", False),
            stream="events.%s" % fm_pool,
            partition=int(data["id"]) % num_partitions,
        )
        new_addresses = set(cfg.addresses)
        # Add new addresses, update remaining
        for addr in new_addresses:
            self.address_configs[addr] = cfg
        # Revoke stale addresses, but only those still mapped to this
        # source: the address may have been taken over (or already removed)
        # by another source, and must not be clobbered or raise KeyError
        for addr in old_addresses - new_addresses:
            if self.address_configs.get(addr) is old_cfg:
                del self.address_configs[addr]
        # Update configs
        self.source_configs[data["id"]] = cfg
        # Update metrics
        metrics["sources_changed"] += 1

    async def delete_source(self, id):
        """
        Remove source config and the address mappings it still owns.
        """
        cfg = self.source_configs.get(id)
        if not cfg:
            return
        for addr in cfg.addresses:
            # Skip addresses re-assigned to another source meanwhile
            if self.address_configs.get(addr) is cfg:
                del self.address_configs[addr]
        del self.source_configs[id]
        metrics["sources_deleted"] += 1
Exemple #11
0
class TrapCollectorService(TornadoService):
    name = "trapcollector"
    leader_group_name = "trapcollector-%(dc)s-%(node)s"
    pooled = True
    process_name = "noc-%(name).10s-%(pool).5s"

    def __init__(self):
        super().__init__()
        # Source id -> SourceConfig
        self.source_configs: Dict[str, SourceConfig] = {}
        # Reverse index: source address -> SourceConfig
        self.address_configs = {}
        # FM pool name -> cached partition count (filled by get_pool_partitions)
        self.pool_partitions: Dict[str, int] = {}
        # Unknown address -> number of dropped messages since the last report
        self.invalid_sources = defaultdict(int)
        # Periodic reporter handle, assigned in on_activate
        self.report_invalid_callback = None
        # NOTE(review): not assigned elsewhere in the visible code; confirm
        # whether it is still needed
        self.mappings_callback = None

    async def on_activate(self):
        """
        Start trap listeners and background jobs when the service activates.

        * Binds the trap server to every address/port produced by
          ``server.iter_listen(config.trapcollector.listen)``. A failure to
          bind is logged and counted in metrics but does not abort activation.
        * Schedules ``report_invalid_sources`` every 60 seconds.
        * Launches ``get_object_mappings`` as a background task. The task is
          kept on ``self.mappings_task``.
        """
        # Listen sockets
        server = TrapServer(service=self)
        for addr, port in server.iter_listen(config.trapcollector.listen):
            self.logger.info("Starting SNMP Trap server at %s:%s", addr, port)
            try:
                server.listen(port, addr)
            except OSError as e:
                metrics["error", ("type", "socket_listen_error")] += 1
                self.logger.error(
                    "Failed to start SNMP Trap server at %s:%s: %s", addr,
                    port, e)
        server.start()
        # Report invalid sources every 60 seconds
        self.logger.info("Starting invalid sources reporting task")
        self.report_invalid_callback = PeriodicCallback(
            self.report_invalid_sources, 60000)
        self.report_invalid_callback.start()
        # Start tracking changes. Keep a strong reference to the task: the
        # event loop only keeps weak references to tasks, so an unreferenced
        # task may be garbage-collected before it completes.
        self.mappings_task = asyncio.get_running_loop().create_task(
            self.get_object_mappings())

    async def get_pool_partitions(self, pool: str) -> int:
        """
        Return the number of partitions of the ``events.<pool>`` stream.

        A non-zero count is cached per pool. A missing or falsy cached value
        triggers a fresh ``get_stream_partitions`` call, and its result is
        stored.
        """
        cached = self.pool_partitions.get(pool)
        if cached:
            return cached
        cached = await self.get_stream_partitions("events.%s" % pool)
        self.pool_partitions[pool] = cached
        return cached

    def lookup_config(self, address: str) -> Optional[SourceConfig]:
        """
        Find the source configuration registered for ``address``.

        Returns the SourceConfig, or None for an unknown source. Unknown
        addresses are counted in ``invalid_sources`` and the
        ``object_not_found`` error metric, but only once at least one address
        mapping is loaded. Presumably this avoids flagging every sender before
        the mappings arrive; confirm.
        """
        found = self.address_configs.get(address)
        if not found:
            if self.address_configs:
                # Register invalid event source
                self.invalid_sources[address] += 1
                metrics["error", ("type", "object_not_found")] += 1
            return None
        return found

    def register_message(self, cfg: SourceConfig, timestamp: int,
                         data: Dict[str, Any]):
        """
        Publish one trap message for the given source.

        The message is serialized with orjson as ``{"ts", "object", "data"}``
        and sent to the source's stream and partition. The ``events_out``
        counter is incremented for every call.
        """
        metrics["events_out"] += 1
        message = {"ts": timestamp, "object": cfg.id, "data": data}
        payload = orjson.dumps(message)
        self.publish(payload, stream=cfg.stream, partition=cfg.partition)

    async def get_object_mappings(self):
        """
        Follow the ``cfgtrap`` data stream for this pool, forever.

        Each iteration issues one ``query`` (``block=1``; presumably it waits
        for changes, confirm against TrapDataStreamClient). When a query
        raises NOCError, the error is logged and the loop retries after one
        second.
        """
        self.logger.info("Starting to track object mappings")
        ds_client = TrapDataStreamClient("cfgtrap", service=self)
        # Track stream changes
        while True:
            try:
                pool_filters = ["pool(%s)" % config.pool]
                await ds_client.query(
                    limit=config.trapcollector.ds_limit,
                    filters=pool_filters,
                    block=1,
                )
            except NOCError as err:
                self.logger.info("Failed to get object mappings: %s", err)
                await asyncio.sleep(1)

    async def report_invalid_sources(self):
        """
        Log and reset the per-address counters of unknown event sources.

        The method does nothing when no invalid sources were recorded.
        Otherwise it logs the total number of dropped messages plus a
        per-address breakdown, then starts a fresh counter.
        """
        counts = self.invalid_sources
        if not counts:
            return
        breakdown = ", ".join("%s: %s" % (src, n) for src, n in counts.items())
        self.logger.info(
            "Dropping %d messages with invalid sources: %s",
            sum(counts.values()),
            breakdown,
        )
        self.invalid_sources = defaultdict(int)

    async def update_source(self, data):
        """
        Create or replace the configuration of a single trap source.

        :param data: Source description. Reads ``id`` and ``addresses`` and
            the optional ``fm_pool`` key.

        Builds a new SourceConfig, points every current address at it and
        revokes addresses the source no longer lists. A stale address is only
        revoked while it still maps to this source. If another source has
        claimed it since (or it is already gone), it is left alone. Otherwise
        this update would drop the other source's mapping or fail with
        KeyError.
        """
        # Get old config
        old_cfg = self.source_configs.get(data["id"])
        old_addresses = set(old_cfg.addresses) if old_cfg else set()
        # Get pool and sharding information
        fm_pool = data.get("fm_pool") or config.pool
        num_partitions = await self.get_pool_partitions(fm_pool)
        # Build new config
        cfg = SourceConfig(
            id=data["id"],
            addresses=tuple(data["addresses"]),
            stream="events.%s" % fm_pool,
            partition=int(data["id"]) % num_partitions,
        )
        new_addresses = set(cfg.addresses)
        # Add new addresses, update remaining
        for addr in new_addresses:
            self.address_configs[addr] = cfg
        # Revoke stale addresses, but only those still owned by this source
        for addr in old_addresses - new_addresses:
            current = self.address_configs.get(addr)
            if current is not None and current.id == cfg.id:
                del self.address_configs[addr]
        # Update configs
        self.source_configs[data["id"]] = cfg
        # Update metrics
        metrics["sources_changed"] += 1

    async def delete_source(self, id):
        """
        Forget a trap source and release its addresses.

        :param id: Source id, the key used in ``source_configs``.

        Unknown ids are ignored. An address is only released while it still
        maps to this source. An address now mapped to another source, or
        already removed, is left untouched instead of raising KeyError or
        dropping the other source's mapping.
        """
        cfg = self.source_configs.get(id)
        if not cfg:
            return
        for addr in cfg.addresses:
            current = self.address_configs.get(addr)
            if current is not None and current.id == cfg.id:
                del self.address_configs[addr]
        del self.source_configs[id]
        metrics["sources_deleted"] += 1