Code example #1
class CLI(object):
    name = "cli"
    default_port = None
    iostream_class = None
    BUFFER_SIZE = config.activator.buffer_size
    MATCH_TAIL = 256
    # Tail of the buffer to check for missed ECMA control characters
    MATCH_MISSED_CONTROL_TAIL = 8
    # Retries on immediate disconnect
    CONNECT_RETRIES = config.activator.connect_retries
    # Timeout after immediate disconnect
    CONNECT_TIMEOUT = config.activator.connect_timeout
    # Compiled capabilities
    HAS_TCP_KEEPALIVE = hasattr(socket, "SO_KEEPALIVE")
    HAS_TCP_KEEPIDLE = hasattr(socket, "TCP_KEEPIDLE")
    HAS_TCP_KEEPINTVL = hasattr(socket, "TCP_KEEPINTVL")
    HAS_TCP_KEEPCNT = hasattr(socket, "TCP_KEEPCNT")
    HAS_TCP_NODELAY = hasattr(socket, "TCP_NODELAY")
    # Time until sending first keepalive probe
    KEEP_IDLE = 10
    # Keepalive packets interval
    KEEP_INTVL = 10
    # Terminate connection after N keepalive failures
    KEEP_CNT = 3
    SYNTAX_ERROR_CODE = "+@@@NOC:SYNTAXERROR@@@+"

    class InvalidPagerPattern(Exception):
        pass

    def __init__(self, script, tos=None):
        self.script = script
        self.profile = script.profile
        self.logger = PrefixLoggerAdapter(self.script.logger, self.name)
        self.iostream = None
        self.motd = ""
        self.ioloop = None
        self.command = None
        self.prompt_stack = []
        self.patterns = self.profile.patterns.copy()
        self.buffer = ""
        self.is_started = False
        self.result = None
        self.error = None
        self.pattern_table = None
        self.collected_data = []
        self.tos = tos
        self.current_timeout = None
        self.is_closed = False
        self.close_timeout = None
        self.close_timeout_lock = Lock()
        self.setup_complete = False
        self.to_raise_privileges = script.credentials.get(
            "raise_privileges", True)
        self.state = "start"
        # State retries
        self.super_password_retries = self.profile.cli_retries_super_password

    def close(self):
        self.script.close_current_session()
        self.close_iostream()
        if self.ioloop:
            self.logger.debug("Closing IOLoop")
            self.ioloop.close(all_fds=True)
            self.ioloop = None
        self.is_closed = True

    def close_iostream(self):
        if self.iostream:
            self.logger.debug("Closing IOStream")
            self.iostream.close()
            self.iostream = None

    def set_state(self, state):
        self.logger.debug("Changing state to <%s>", state)
        self.state = state

    def maybe_close(self):
        with self.close_timeout_lock:
            if not self.close_timeout:
                return  # Race with execute(), no need to close
            if self.ioloop:
                self.ioloop.remove_timeout(self.close_timeout)
            self.close_timeout = None
            self.close()

    def reset_close_timeout(self):
        with self.close_timeout_lock:
            if self.close_timeout:
                self.logger.debug("Removing close timeout")
                self.ioloop.remove_timeout(self.close_timeout)
                self.close_timeout = None

    def deferred_close(self, session_timeout):
        if self.is_closed or not self.iostream:
            return
        self.logger.debug("Setting close timeout to %ss", session_timeout)
        # Cannot call call_later directly due to
        # thread-safety problems
        # See tornado issue #1773
        tornado.ioloop.IOLoop.instance().add_callback(self._set_close_timeout,
                                                      session_timeout)

    def _set_close_timeout(self, session_timeout):
        """
        Wrapper to deal with IOLoop.add_timeout thread safety problem
        :param session_timeout:
        :return:
        """
        with self.close_timeout_lock:
            self.close_timeout = tornado.ioloop.IOLoop.instance().call_later(
                session_timeout, self.maybe_close)

    def create_iostream(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if self.tos:
            self.logger.debug("Setting IP ToS to %d", self.tos)
            s.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, self.tos)
        if self.HAS_TCP_NODELAY:
            self.logger.info("Setting TCP NODELAY")
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if self.HAS_TCP_KEEPALIVE:
            self.logger.info("Settings KEEPALIVE")
            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            if self.HAS_TCP_KEEPIDLE:
                self.logger.info("Setting TCP KEEPIDLE to %d", self.KEEP_IDLE)
                s.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE,
                             self.KEEP_IDLE)
            if self.HAS_TCP_KEEPINTVL:
                self.logger.info("Setting TCP KEEPINTVL to %d",
                                 self.KEEP_INTVL)
                s.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL,
                             self.KEEP_INTVL)
            if self.HAS_TCP_KEEPCNT:
                self.logger.info("Setting TCP KEEPCNT to %d", self.KEEP_CNT)
                s.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, self.KEEP_CNT)
        return self.iostream_class(s, self)

    def set_timeout(self, timeout):
        if timeout:
            self.logger.debug("Setting timeout: %ss", timeout)
            self.current_timeout = datetime.timedelta(seconds=timeout)
        else:
            if self.current_timeout:
                self.logger.debug("Resetting timeouts")
            self.current_timeout = None

    def run_sync(self, func, *args, **kwargs):
        """
        Simplified implementation of IOLoop.run_sync
        to distinguish real TimeoutErrors from incomplete futures
        :param func:
        :param args:
        :param kwargs:
        :return:
        """
        future_cell = [None]

        def run():
            try:
                result = func(*args, **kwargs)
                if result is not None:
                    result = tornado.gen.convert_yielded(result)
                future_cell[0] = result
            except Exception:
                future_cell[0] = tornado.concurrent.TracebackFuture()
                future_cell[0].set_exc_info(sys.exc_info())
            self.ioloop.add_future(future_cell[0],
                                   lambda future: self.ioloop.stop())

        self.ioloop.add_callback(run)
        self.ioloop.start()
        if not future_cell[0].done():
            self.logger.info("Incomplete feature left. Restarting IOStream")
            self.close_iostream()
            # Retain the cryptic message as is,
            # mark the future as done
            future_cell[0].set_exception(
                tornado.gen.TimeoutError(
                    "Operation timed out after %s seconds" % None))
        return future_cell[0].result()

    def execute(
        self,
        cmd,
        obj_parser=None,
        cmd_next=None,
        cmd_stop=None,
        ignore_errors=False,
        allow_empty_response=True,
    ):
        self.reset_close_timeout()
        self.buffer = ""
        self.command = cmd
        self.error = None
        self.ignore_errors = ignore_errors
        self.allow_empty_response = allow_empty_response
        if not self.ioloop:
            self.logger.debug("Creating IOLoop")
            self.ioloop = tornado.ioloop.IOLoop()
        if obj_parser:
            parser = functools.partial(self.parse_object_stream, obj_parser,
                                       cmd_next, cmd_stop)
        else:
            parser = self.read_until_prompt
        with Span(server=self.script.credentials.get("address"),
                  service=self.name,
                  in_label=cmd) as s:
            self.run_sync(self.submit, parser)
            if self.error:
                if s:
                    s.error_text = str(self.error)
                raise self.error
            else:
                return self.result

    @tornado.gen.coroutine
    def submit(self, parser=None):
        # Create iostream and connect, when necessary
        if not self.iostream:
            self.iostream = self.create_iostream()
            address = (
                self.script.credentials.get("address"),
                self.script.credentials.get("cli_port", self.default_port),
            )
            self.logger.debug("Connecting %s", address)
            try:
                metrics["cli_connection", ("proto", self.name)] += 1
                yield self.iostream.connect(address)
                metrics["cli_connection_success", ("proto", self.name)] += 1
            except tornado.iostream.StreamClosedError:
                self.logger.debug("Connection refused")
                metrics["cli_connection_refused", ("proto", self.name)] += 1
                self.error = CLIConnectionRefused("Connection refused")
                raise tornado.gen.Return(None)
            self.logger.debug("Connected")
            yield self.iostream.startup()
        # Perform all necessary login procedures
        metrics["cli_commands", ("proto", self.name)] += 1
        if not self.is_started:
            yield self.on_start()
            self.motd = yield self.read_until_prompt()
            self.script.set_motd(self.motd)
            self.is_started = True
        # Send command
        # @todo: encode to object's encoding
        if self.profile.batch_send_multiline or self.profile.command_submit not in self.command:
            yield self.send(self.command)
        else:
            # Send multiline commands line-by-line
            for cmd in self.command.split(self.profile.command_submit):
                # Send line
                yield self.send(cmd + self.profile.command_submit)
                # @todo: Await response
        parser = parser or self.read_until_prompt
        self.result = yield parser()
        self.logger.debug("Command: %s\n%s", self.command.strip(), self.result)
        if (self.profile.rx_pattern_syntax_error and not self.ignore_errors
                and parser == self.read_until_prompt
                and (self.profile.rx_pattern_syntax_error.search(self.result)
                     or self.result == self.SYNTAX_ERROR_CODE)):
            error_text = self.result
            if self.profile.send_on_syntax_error and self.name != "beef_cli":
                self.allow_empty_response = True
                yield self.on_error_sequence(self.profile.send_on_syntax_error,
                                             self.command, error_text)
            self.error = self.script.CLISyntaxError(error_text)
            self.result = None
        raise tornado.gen.Return(self.result)

    def cleaned_input(self, s):
        """
        Clean up received input and wipe out control sequences
        and rogue chars
        """
        # Wipe out rogue chars
        if self.profile.rogue_chars:
            for rc in self.profile.rogue_chars:
                try:
                    s = rc.sub("", s)  # rc is compiled regular expression
                except AttributeError:
                    s = s.replace(rc, "")  # rc is a string
        # Clean control sequences
        return self.profile.cleaned_input(s)

    @tornado.gen.coroutine
    def send(self, cmd):
        # @todo: Apply encoding
        cmd = str(cmd)
        self.logger.debug("Send: %r", cmd)
        yield self.iostream.write(cmd)

    @tornado.gen.coroutine
    def read_until_prompt(self):
        connect_retries = self.CONNECT_RETRIES
        while True:
            try:
                metrics["cli_reads", ("proto", self.name)] += 1
                f = self.iostream.read_bytes(self.BUFFER_SIZE, partial=True)
                if self.current_timeout:
                    r = yield tornado.gen.with_timeout(self.current_timeout, f)
                else:
                    r = yield f
                if r == self.SYNTAX_ERROR_CODE:
                    metrics["cli_syntax_errors", ("proto", self.name)] += 1
                    raise tornado.gen.Return(self.SYNTAX_ERROR_CODE)
                metrics["cli_read_bytes", ("proto", self.name)] += len(r)
                if self.script.to_track:
                    self.script.push_cli_tracking(r, self.state)
            except tornado.iostream.StreamClosedError:
                # Check if the remote end closed the connection right
                # after it was established
                if not self.is_started and connect_retries:
                    self.logger.info(
                        "Connection reset. %d retries left. Waiting %d seconds",
                        connect_retries,
                        self.CONNECT_TIMEOUT,
                    )
                    while connect_retries:
                        yield tornado.gen.sleep(self.CONNECT_TIMEOUT)
                        connect_retries -= 1
                        self.iostream = self.create_iostream()
                        address = (
                            self.script.credentials.get("address"),
                            self.script.credentials.get(
                                "cli_port", self.default_port),
                        )
                        self.logger.debug("Connecting %s", address)
                        try:
                            yield self.iostream.connect(address)
                            yield self.iostream.startup()
                            break
                        except tornado.iostream.StreamClosedError:
                            if not connect_retries:
                                raise tornado.iostream.StreamClosedError()
                    continue
                else:
                    raise tornado.iostream.StreamClosedError()
            except tornado.gen.TimeoutError:
                self.logger.info("Timeout error")
                metrics["cli_timeouts", ("proto", self.name)] += 1
                # IOStream must be closed to prevent hanging read callbacks
                self.close_iostream()
                raise tornado.gen.TimeoutError("Timeout")
            self.logger.debug("Received: %r", r)
            # Clean input
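            # If an ESC (\x1b) landed near the end of the previous chunk,
            # a control sequence may be split across reads: re-clean the
            # whole buffer instead of only the new chunk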
            if self.buffer.find("\x1b", -self.MATCH_MISSED_CONTROL_TAIL) != -1:
                self.buffer = self.cleaned_input(self.buffer + r)
            else:
                self.buffer += self.cleaned_input(r)
            # Try to find matched pattern
            offset = max(0, len(self.buffer) - self.MATCH_TAIL)
            for rx, handler in six.iteritems(self.pattern_table):
                match = rx.search(self.buffer, offset)
                if match:
                    self.logger.debug("Match: %s", rx.pattern)
                    matched = self.buffer[:match.start()]
                    self.buffer = self.buffer[match.end():]
                    if isinstance(handler, tuple):
                        metrics["cli_state",
                                ("state", handler[0].__name__)] += 1
                        r = yield handler[0](matched, match, *handler[1:])
                    else:
                        metrics["cli_state", ("state", handler.__name__)] += 1
                        r = yield handler(matched, match)
                    if r is not None:
                        raise tornado.gen.Return(r)
                    else:
                        break  # This state is processed

    @tornado.gen.coroutine
    def parse_object_stream(self, parser=None, cmd_next=None, cmd_stop=None):
        """
        :param parser: callable accepting buffer and returning
                       (key, data, rest) or None.
                       key -- string with object distinguisher
                       data -- dict containing attributes
                       rest -- unparsed rest of string
        :param cmd_next: Sequence to go to the next page
        :param cmd_stop: Sequence to stop
        :return:
        """
        self.logger.debug("Parsing object stream")
        objects = []
        seen = set()
        buffer = ""
        repeats = 0
        r_key = None
        stop_sent = False
        done = False
        while not done:
            r = yield self.iostream.read_bytes(self.BUFFER_SIZE, partial=True)
            if self.script.to_track:
                self.script.push_cli_tracking(r, self.state)
            self.logger.debug("Received: %r", r)
            buffer = self.cleaned_input(buffer + r)
            # Check for syntax error
            if (self.profile.rx_pattern_syntax_error and not self.ignore_errors
                    and self.profile.rx_pattern_syntax_error.search(
                        buffer)):
                error_text = buffer
                if self.profile.send_on_syntax_error:
                    yield self.on_error_sequence(
                        self.profile.send_on_syntax_error, self.command,
                        error_text)
                self.error = self.script.CLISyntaxError(error_text)
                break
            # Then check for operation error
            if (self.profile.rx_pattern_operation_error
                    and self.profile.rx_pattern_operation_error.search(
                        buffer)):
                self.error = self.script.CLIOperationError(buffer)
                break
            # Parse all possible objects
            while buffer:
                pr = parser(buffer)
                if not pr:
                    break  # No new objects
                key, obj, buffer = pr
                if key not in seen:
                    seen.add(key)
                    objects += [obj]
                    repeats = 0
                    r_key = None
                elif r_key:
                    if r_key == key:
                        repeats += 1
                        if repeats >= 3 and cmd_stop and not stop_sent:
                            # Stop loop at final page
                            # After 3 repeats
                            self.logger.debug("Stopping stream. Sending %r" %
                                              cmd_stop)
                            self.send(cmd_stop)
                            stop_sent = True
                else:
                    r_key = key
                    if cmd_next:
                        self.logger.debug("Next screen. Sending %r" % cmd_next)
                        self.send(cmd_next)
            # Check for prompt
            for rx, handler in six.iteritems(self.pattern_table):
                offset = max(0, len(buffer) - self.MATCH_TAIL)
                match = rx.search(buffer, offset)
                if match:
                    self.logger.debug("Match: %s", rx.pattern)
                    matched = buffer[:match.start()]
                    buffer = buffer[match.end():]
                    r = handler(matched, match)
                    if r is not None:
                        self.logger.debug("Prompt matched")
                        done = True
                        break
        raise tornado.gen.Return(objects)

    def send_pager_reply(self, data, match):
        """
        Send proper pager reply
        """
        pg = match.group(0)
        for p, c in self.patterns["more_patterns_commands"]:
            if p.search(pg):
                self.collected_data += [data]
                self.send(c)
                return
        raise self.InvalidPagerPattern(pg)

    def expect(self, patterns, timeout=None):
        """
        Send command if not none and set reply patterns
        """
        self.pattern_table = {}
        for pattern_name in patterns:
            rx = self.patterns.get(pattern_name)
            if not rx:
                continue
            self.pattern_table[rx] = patterns[pattern_name]
        self.set_timeout(timeout)

    @tornado.gen.coroutine
    def on_start(self, data=None, match=None):
        self.set_state("start")
        if self.profile.setup_sequence and not self.setup_complete:
            self.expect({"setup": self.on_setup_sequence},
                        self.profile.cli_timeout_setup)
        else:
            self.expect(
                {
                    "username": self.on_username,
                    "password": self.on_password,
                    "unprivileged_prompt": self.on_unprivileged_prompt,
                    "prompt": self.on_prompt,
                    "pager": self.send_pager_reply,
                },
                self.profile.cli_timeout_start,
            )

    @tornado.gen.coroutine
    def on_username(self, data, match):
        self.set_state("username")
        self.send((self.script.credentials.get("user", "") or "") +
                  (self.profile.username_submit or "\n"))
        self.expect(
            {
                "username": (self.on_failure, CLIAuthFailed),
                "password": self.on_password,
                "unprivileged_prompt": self.on_unprivileged_prompt,
                "prompt": self.on_prompt,
            },
            self.profile.cli_timeout_user,
        )

    @tornado.gen.coroutine
    def on_password(self, data, match):
        self.set_state("password")
        self.send((self.script.credentials.get("password", "") or "") +
                  (self.profile.password_submit or "\n"))
        self.expect(
            {
                "username": (self.on_failure, CLIAuthFailed),
                "password": (self.on_failure, CLIAuthFailed),
                "unprivileged_prompt": self.on_unprivileged_prompt,
                "super_password": self.on_super_password,
                "prompt": self.on_prompt,
                "pager": self.send_pager_reply,
            },
            self.profile.cli_timeout_password,
        )

    @tornado.gen.coroutine
    def on_unprivileged_prompt(self, data, match):
        self.set_state("unprivileged_prompt")
        if self.to_raise_privileges:
            # Start privilege raising sequence
            if not self.profile.command_super:
                self.on_failure(data, match, CLINoSuperCommand)
            self.send(self.profile.command_super +
                      (self.profile.command_submit or "\n"))
            # Do not remove `pager` section
            # It fixes this situation on Huawei MA5300:
            # xxx>enable
            # { <cr>|level-value<U><1,15> }:
            # xxx#
            self.expect(
                {
                    "username": self.on_super_username,
                    "password": self.on_super_password,
                    "prompt": self.on_prompt,
                    "pager": self.send_pager_reply,
                },
                self.profile.cli_timeout_super,
            )
        else:
            # Do not raise privileges
            # Use unprivileged prompt as primary prompt
            self.patterns["prompt"] = self.patterns["unprivileged_prompt"]
            return self.on_prompt(data, match)

    @tornado.gen.coroutine
    def on_failure(self, data, match, error_cls=None):
        self.set_state("failure")
        error_cls = error_cls or CLIError
        raise error_cls(self.buffer or data or None)

    @tornado.gen.coroutine
    def on_prompt(self, data, match):
        self.set_state("prompt")
        if not self.allow_empty_response:
            s_data = data.strip()
            if not s_data or s_data == self.command.strip():
                return None
        if not self.is_started:
            self.resolve_pattern_prompt(match)
        d = "".join(self.collected_data + [data])
        self.collected_data = []
        self.expect({"prompt": self.on_prompt, "pager": self.send_pager_reply})
        return d

    @tornado.gen.coroutine
    def on_super_username(self, data, match):
        self.set_state("super_username")
        self.send((self.script.credentials.get("user", "") or "") +
                  (self.profile.username_submit or "\n"))
        self.expect(
            {
                "username": (self.on_failure, CLILowPrivileges),
                "password": self.on_super_password,
                "unprivileged_prompt": self.on_unprivileged_prompt,
                "prompt": self.on_prompt,
                "pager": self.send_pager_reply,
            },
            self.profile.cli_timeout_user,
        )

    @tornado.gen.coroutine
    def on_super_password(self, data, match):
        self.set_state("super_password")
        self.send((self.script.credentials.get("super_password", "") or "") +
                  (self.profile.username_submit or "\n"))
        if self.super_password_retries > 1:
            unprivileged_handler = self.on_unprivileged_prompt
            self.super_password_retries -= 1
        else:
            unprivileged_handler = (self.on_failure, CLILowPrivileges)
        self.expect(
            {
                "prompt": self.on_prompt,
                "password": (self.on_failure, CLILowPrivileges),
                "super_password": (self.on_failure, CLILowPrivileges),
                "pager": self.send_pager_reply,
                "unprivileged_prompt": unprivileged_handler,
            },
            self.profile.cli_timeout_password,
        )

    @tornado.gen.coroutine
    def on_setup_sequence(self, data, match):
        self.set_state("setup")
        self.logger.debug("Performing setup sequence: %s",
                          self.profile.setup_sequence)
        lseq = len(self.profile.setup_sequence)
        for i, c in enumerate(self.profile.setup_sequence):
            if isinstance(c, six.integer_types) or isinstance(c, float):
                yield tornado.gen.sleep(c)
                continue
            cmd = c % self.script.credentials
            yield self.send(cmd)
            # Wait for the response and drop it
            if i < lseq - 1:
                resp = yield tornado.gen.with_timeout(
                    self.ioloop.time() + 30,
                    future=self.iostream.read_bytes(4096, partial=True),
                    io_loop=self.ioloop,
                )
                if self.script.to_track:
                    self.script.push_cli_tracking(resp, self.state)
                self.logger.debug("Receiving: %r", resp)
        self.logger.debug("Setup sequence complete")
        self.setup_complete = True
        yield self.on_start(data, match)

    def resolve_pattern_prompt(self, match):
        """
        Resolve adaptive pattern prompt
        """
        old_pattern_prompt = self.patterns["prompt"].pattern
        pattern_prompt = old_pattern_prompt
        sl = self.profile.can_strip_hostname_to
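        # Build a pattern that matches the hostname truncated anywhere past
        # position `sl`: keep the fixed head literal, then nest each trailing
        # character as an optional group from the right, e.g. "router1" with
        # sl=3 becomes rou(?:t(?:e(?:r(?:1)?)?)?)?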
        for k, v in six.iteritems(match.groupdict()):
            if v:
                if k == "hostname" and sl and len(v) > sl:
                    ss = list(reversed(v[sl:]))
                    v = re.escape(v[:sl]) + reduce(
                        lambda x, y: "(?:%s%s)?" % (re.escape(y), x),
                        ss[1:],
                        "(?:%s)?" % re.escape(ss[0]),
                    )
                else:
                    v = re.escape(v)
                pattern_prompt = replace_re_group(pattern_prompt,
                                                  "(?P<%s>" % k, v)
                pattern_prompt = replace_re_group(pattern_prompt, "(?P=%s" % k,
                                                  v)
            else:
                self.logger.error("Invalid prompt pattern")
        if old_pattern_prompt != pattern_prompt:
            self.logger.debug("Refining pattern prompt to %r", pattern_prompt)
        self.patterns["prompt"] = re.compile(pattern_prompt,
                                             re.DOTALL | re.MULTILINE)

    def push_prompt_pattern(self, pattern):
        """
        Override prompt pattern
        """
        self.logger.debug("New prompt pattern: %s", pattern)
        self.prompt_stack += [self.patterns["prompt"]]
        self.patterns["prompt"] = re.compile(pattern, re.DOTALL | re.MULTILINE)
        self.pattern_table[self.patterns["prompt"]] = self.on_prompt

    def pop_prompt_pattern(self):
        """
        Restore prompt pattern
        """
        self.logger.debug("Restore prompt pattern")
        pattern = self.prompt_stack.pop(-1)
        self.patterns["prompt"] = pattern
        self.pattern_table[self.patterns["prompt"]] = self.on_prompt

    def get_motd(self):
        """
        Return collected message of the day
        """
        return self.motd

    def set_script(self, script):
        self.script = script
        self.logger = PrefixLoggerAdapter(self.script.logger, self.name)
        self.reset_close_timeout()
        if self.motd:
            self.script.set_motd(self.motd)

    def setup_session(self):
        if self.profile.setup_session:
            self.logger.debug("Setup session")
            self.profile.setup_session(self.script)

    def shutdown_session(self):
        if self.profile.shutdown_session:
            self.logger.debug("Shutdown session")
            self.profile.shutdown_session(self.script)

    @tornado.gen.coroutine
    def on_error_sequence(self, seq, command, error_text):
        """
        Process error sequence
        :param seq:
        :param command:
        :param error_text:
        :return:
        """
        if isinstance(seq, six.string_types):
            self.logger.debug("Recovering from error. Sending %r", seq)
            yield self.iostream.write(seq)
        elif callable(seq):
            if tornado.gen.is_coroutine_function(seq):
                # Yield coroutine
                yield seq(self, command, error_text)
            else:
                seq = seq(self, command, error_text)
                yield self.iostream.write(seq)
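
A minimal, self-contained sketch of the pattern-table dispatch that
read_until_prompt() performs on every received chunk may help: only the tail
of the buffer (MATCH_TAIL characters) is searched, and the first matching
pattern's handler receives everything collected before the match. The
patterns and handlers below are illustrative stand-ins, not the profile's
real ones.

import re

MATCH_TAIL = 256

def dispatch(buffer, pattern_table):
    # Search only the buffer tail, mirroring read_until_prompt()
    offset = max(0, len(buffer) - MATCH_TAIL)
    for rx, handler in pattern_table.items():
        match = rx.search(buffer, offset)
        if match:
            collected = buffer[:match.start()]
            rest = buffer[match.end():]
            return handler(collected, match), rest
    return None, buffer  # no pattern matched yet; keep reading

pattern_table = {
    re.compile(r"--More--"): lambda data, m: ("pager", data),
    re.compile(r"\S+[>#] ?$"): lambda data, m: ("prompt", data),
}

print(dispatch("sw1# show clock\r\n10:00:00\r\nsw1# ", pattern_table))
# -> (('prompt', 'sw1# show clock\r\n10:00:00\r\n'), '')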
Code example #2
File: base.py Project: skripkar/noc
class BaseLoader(object):
    """
    Import directory structure:
    var/
        import/
            <system name>/
                <loader name>/
                    import.csv[.gz]  -- state to load, can have .gz extension
                    mappings.csv -- ID mappings
                    archive/
                        import-YYYY-MM-DD-HH-MM-SS.csv.gz -- imported state

    Import file format: CSV, Unix line endings, UTF-8, comma-separated.
    The first column is the record id in terms of the connected system;
    the remaining columns must be defined in the *fields* variable.

    The file must be sorted by the first field, either as strings or as
    numbers; the sort order must not change between imports.

    mappings.csv -- CSV, Unix line endings, UTF-8, comma-separated
    mapping of IDs between NOC and the remote system. Populated
    automatically by the loader.

    :param fields: List of either field names or tuples of
        (field name, related loader name)
    """
    # Loader name
    name = None
    # Loader model
    model = None
    # Mapped fields
    mapped_fields = {}

    fields = []

    # List of tags to add to the created records
    tags = []

    PREFIX = config.path.etl_import
    rx_archive = re.compile(r"^import-\d{4}(?:-\d{2}){5}\.csv\.gz$")

    # Discard records which cannot be dereferenced
    discard_deferred = False
    # Ignore auto-generated unique fields
    ignore_unique = set(["bi_id"])

    REPORT_INTERVAL = 1000

    class Deferred(Exception):
        pass

    def __init__(self, chain):
        self.chain = chain
        self.system = chain.system
        self.logger = PrefixLoggerAdapter(
            logger, "%s][%s" % (self.system.name, self.name))
        self.import_dir = os.path.join(self.PREFIX, self.system.name,
                                       self.name)
        self.archive_dir = os.path.join(self.import_dir, "archive")
        self.mappings_path = os.path.join(self.import_dir, "mappings.csv")
        self.mappings = {}
        self.new_state_path = None
        self.c_add = 0
        self.c_change = 0
        self.c_delete = 0
        # Build clean map
        self.clean_map = dict(
            (n, self.clean_str)
            for n in self.fields)  # field name -> clean function
        self.pending_deletes = []  # (id, string)
        self.referred_errors = []  # (id, string)
        if self.is_document:
            import mongoengine.errors
            unique_fields = [
                f.name for f in self.model._fields.itervalues()
                if f.unique and f.name not in self.ignore_unique
            ]
            self.integrity_exception = mongoengine.errors.NotUniqueError
        else:
            # Third-party modules
            import django.db.utils

            unique_fields = [
                f.name for f in self.model._meta.fields
                if f.unique and f.name != self.model._meta.pk.name
                and f.name not in self.ignore_unique
            ]
            self.integrity_exception = django.db.utils.IntegrityError
        if unique_fields:
            self.unique_field = unique_fields[0]
        else:
            self.unique_field = None

    @property
    def is_document(self):
        """
        Returns True if the model is a Document, False if it is a Model
        """
        return hasattr(self.model, "_fields")

    def load_mappings(self):
        """
        Load mappings file
        """
        if self.model:
            if self.is_document:
                self.update_document_clean_map()
            else:
                self.update_model_clean_map()
        if not os.path.exists(self.mappings_path):
            return
        self.logger.info("Loading mappings from %s", self.mappings_path)
        with open(self.mappings_path) as f:
            reader = csv.reader(f)
            for k, v in reader:
                self.mappings[k] = v
        self.logger.info("%d mappings restored", len(self.mappings))

    def get_new_state(self):
        """
        Returns file object of new state, or None when not present
        """
        # Try import.csv
        path = os.path.join(self.import_dir, "import.csv")
        if os.path.isfile(path):
            logger.info("Loading from %s", path)
            self.new_state_path = path
            return open(path, "r")
        # Try import.csv.gz
        path += ".gz"
        if os.path.isfile(path):
            logger.info("Loading from %s", path)
            self.new_state_path = path
            return gzip.GzipFile(path, "r")
        # No data to import
        return None

    def get_current_state(self):
        """
        Returns file object of current state or None
        """
        self.load_mappings()
        if not os.path.isdir(self.archive_dir):
            self.logger.info("Creating archive directory: %s",
                             self.archive_dir)
            try:
                os.mkdir(self.archive_dir)
            except OSError as e:
                self.logger.error("Failed to create directory: %s (%s)",
                                  self.archive_dir, e)
                # @todo: Die
        if os.path.isdir(self.archive_dir):
            fn = sorted(f for f in os.listdir(self.archive_dir)
                        if self.rx_archive.match(f))
        else:
            fn = []
        if fn:
            path = os.path.join(self.archive_dir, fn[-1])
            logger.info("Current state from %s", path)
            return gzip.GzipFile(path, "r")
        # No current state
        return six.StringIO("")

    def diff(self, old, new):
        """
        Compare old and new CSV streams and yield pairs:
        * old, new -- when changed
        * old, None -- when removed
        * None, new -- when added
        """
        def getnext(g):
            try:
                return next(g)
            except StopIteration:
                return None

        o = getnext(old)
        n = getnext(new)
        while o or n:
            if not o:
                # New
                yield None, n
                n = getnext(new)
            elif not n:
                # Removed
                yield o, None
                o = getnext(old)
            else:
                if n[0] == o[0]:
                    # Changed
                    if n != o:
                        yield o, n
                    n = getnext(new)
                    o = getnext(old)
                elif n[0] < o[0]:
                    # Added
                    yield None, n
                    n = getnext(new)
                else:
                    # Removed
                    yield o, None
                    o = getnext(old)

    def load(self):
        """
        Import new data
        """
        self.logger.info("Importing")
        ns = self.get_new_state()
        if not ns:
            self.logger.info("No new state, skipping")
            self.load_mappings()
            return
        current_state = csv.reader(self.get_current_state())
        new_state = csv.reader(ns)
        deferred_add = []
        deferred_change = []
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                try:
                    self.on_add(n)
                except self.Deferred:
                    if not self.discard_deferred:
                        deferred_add += [n]
            elif o and n is None:
                self.on_delete(o)
            else:
                try:
                    self.on_change(o, n)
                except self.Deferred:
                    if not self.discard_deferred:
                        deferred_change += [(o, n)]
            rn = self.c_add + self.c_change + self.c_delete
            if rn > 0 and rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)
        # Add deferred records
        while len(deferred_add):
            nd = []
            for row in deferred_add:
                try:
                    self.on_add(row)
                except self.Deferred:
                    nd += [row]
            if len(nd) == len(deferred_add):
                raise Exception("Unable to resolve deferred references")
            deferred_add = nd
            rn = self.c_add + self.c_change + self.c_delete
            if rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)
        # Change deferred records
        while len(deferred_change):
            nd = []
            for o, n in deferred_change:
                try:
                    self.on_change(o, n)
                except self.Deferred:
                    nd += [(o, n)]
            if len(nd) == len(deferred_change):
                raise Exception("Unable to resolve deferred references")
            deferred_change = nd
            rn = self.c_add + self.c_change + self.c_delete
            if rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)

    def find_object(self, v):
        """
        Find object by remote system/remote id
        :param v:
        :return:
        """
        if not v.get("remote_system") or not v.get("remote_id"):
            self.logger.warning("RS or RID not found")
            return None
        find_query = {
            "remote_system": v.get("remote_system"),
            "remote_id": v.get("remote_id")
        }
        try:
            return self.model.objects.get(**find_query)
        except self.model.MultipleObjectsReturned:
            if self.unique_field:
                # Narrow the search by the unique field; if nothing matches,
                # retry the broader remote_system/remote_id query and take
                # the last record
                find_query[self.unique_field] = v.get(self.unique_field)
                r = self.model.objects.filter(**find_query)
                if not r:
                    del find_query[self.unique_field]
                    r = self.model.objects.filter(**find_query)
                return list(r)[-1]
            raise self.model.MultipleObjectsReturned
        except self.model.DoesNotExist:
            return None

    def create_object(self, v):
        """
        Create object with attributes. Override to save complex
        data structures
        """
        for k, nv in six.iteritems(v):
            if k == "tags":
                # Merge tags
                nv = sorted("%s:%s" % (self.system.name, x) for x in nv)
                v[k] = nv
        o = self.model(**v)
        try:
            o.save()
        except self.integrity_exception as e:
            self.logger.warning("Integrity error: %s", e)
            assert self.unique_field
            if not self.is_document:
                from django.db import connection
                connection._rollback()
            # Fallback to change object
            o = self.model.objects.get(
                **{self.unique_field: v[self.unique_field]})
            for k, nv in six.iteritems(v):
                setattr(o, k, nv)
            o.save()
        return o

    def change_object(self, object_id, v):
        """
        Change object with attributes
        """
        self.logger.debug("Changed object")
        # See: https://code.getnoc.com/noc/noc/merge_requests/49
        try:
            o = self.model.objects.get(pk=object_id)
        except self.model.DoesNotExist:
            self.logger.error("Cannot change %s:%s: Does not exists",
                              self.name, object_id)
            return None
        for k, nv in six.iteritems(v):
            if k == "tags":
                # Merge tags
                ov = o.tags or []
                nv = sorted([
                    x for x in ov if not x.startswith(self.system.name + ":")
                ] + ["%s:%s" % (self.system.name, x) for x in nv])
            setattr(o, k, nv)
        o.save()
        return o

    def on_add(self, row):
        """
        Create new record
        """
        self.logger.debug("Add: %s", ";".join(row))
        v = self.clean(row)
        # @todo: Check whether the record already exists
        if self.fields[0] in v:
            del v[self.fields[0]]
        if hasattr(self.model, "remote_system"):
            o = self.find_object(v)
        else:
            o = None
        if o:
            self.c_change += 1
            # Lost&found object with same remote_id
            self.logger.debug("Lost and Found object")
            vv = {
                "remote_system": v["remote_system"],
                "remote_id": v["remote_id"]
            }
            # for fn, nv in zip(self.fields[1:], row[1:]):
            for fn, nv in six.iteritems(v):
                if fn in vv:
                    continue
                if getattr(o, fn) != nv:
                    vv[fn] = nv
            self.change_object(o.id, vv)
            # Restore mappings
            self.set_mappings(row[0], o.id)
        else:
            self.c_add += 1
            o = self.create_object(v)
            self.set_mappings(row[0], o.id)

    def on_change(self, o, n):
        """
        Create change record
        """
        self.logger.debug("Change: %s", ";".join(n))
        self.c_change += 1
        v = self.clean(n)
        vv = {"remote_system": v["remote_system"], "remote_id": v["remote_id"]}
        for fn, (ov, nv) in zip(self.fields[1:],
                                itertools.izip_longest(o[1:], n[1:])):
            if ov != nv:
                self.logger.debug("   %s: %s -> %s", fn, ov, nv)
                vv[fn] = v[fn]
        if n[0] in self.mappings:
            self.change_object(self.mappings[n[0]], vv)
        else:
            self.logger.error("Cannot map id '%s'. Skipping.", n[0])

    def on_delete(self, row):
        """
        Delete record
        """
        self.pending_deletes += [(row[0], ";".join(row))]

    def purge(self):
        """
        Perform pending deletes
        """
        for r_id, msg in reversed(self.pending_deletes):
            self.logger.debug("Delete: %s", msg)
            self.c_delete += 1
            try:
                obj = self.model.objects.get(pk=self.mappings[r_id])
                obj.delete()
            except ValueError as e:  # Referred error
                self.logger.error("%s", str(e))
                self.referred_errors += [(r_id, msg)]
            except self.model.DoesNotExist:
                pass  # Already deleted
        self.pending_deletes = []

    def save_state(self):
        """
        Save current state
        """
        if not self.new_state_path:
            return
        self.logger.info("Summary: %d new, %d changed, %d removed", self.c_add,
                         self.c_change, self.c_delete)
        self.logger.info("Error delete by reffered: %s",
                         "\n".join(self.reffered_errors))
        t = time.localtime()
        archive_path = os.path.join(
            self.archive_dir,
            "import-%04d-%02d-%02d-%02d-%02d-%02d.csv.gz" % tuple(t[:6]))
        self.logger.info("Moving %s to %s", self.new_state_path, archive_path)
        if self.new_state_path.endswith(".gz"):
            # Simply move the file
            shutil.move(self.new_state_path, archive_path)
        else:
            # Compress the file
            self.logger.info("Compressing")
            with open(self.new_state_path, "r") as s:
                with gzip.open(archive_path, "w") as d:
                    d.write(s.read())
            os.unlink(self.new_state_path)
        self.logger.info("Saving mappings to %s", self.mappings_path)
        mdata = "\n".join("%s,%s" % (k, self.mappings[k])
                          for k in sorted(self.mappings))
        safe_rewrite(self.mappings_path, mdata)

    def clean(self, row):
        """
        Cleanup row and return a dict of field name -> value
        """
        r = dict((k, self.clean_map[k](v)) for k, v in zip(self.fields, row))
        # Fill integration fields
        r["remote_system"] = self.system.remote_system
        r["remote_id"] = self.clean_str(row[0])
        return r

    def clean_str(self, value):
        if value:
            if isinstance(value, str):
                return unicode(value, "utf-8")
            elif not isinstance(value, six.string_types):
                return str(value)
            else:
                return value
        else:
            return None

    def clean_map_str(self, mappings, value):
        value = self.clean_str(value)
        if value:
            try:
                value = mappings[value]
            except KeyError:
                raise self.Deferred
        return value

    def clean_bool(self, value):
        if value == "":
            return None
        try:
            return int(value) != 0
        except ValueError:
            pass
        value = value.lower()
        return value in ("t", "true", "y", "yes")

    def clean_reference(self, mappings, r_model, value):
        if not value:
            return None
        else:
            # @todo: Get proper mappings
            try:
                value = mappings[value]
            except KeyError:
                self.logger.info("Deferred. Unknown value %s:%s", r_model,
                                 value)
                raise self.Deferred()
            return self.chain.cache[r_model, value]

    def clean_int_reference(self, mappings, r_model, value):
        if not value:
            return None
        else:
            # @todo: Get proper mappings
            try:
                value = int(mappings[value])
            except KeyError:
                self.logger.info("Deferred. Unknown value %s:%s", r_model,
                                 value)
                raise self.Deferred()
            return self.chain.cache[r_model, value]

    def set_mappings(self, rv, lv):
        self.logger.debug("Set mapping remote: %s, local: %s", rv, lv)
        self.mappings[str(rv)] = str(lv)

    def update_document_clean_map(self):
        from mongoengine.fields import BooleanField, ReferenceField
        from noc.lib.nosql import PlainReferenceField, ForeignKeyField

        for fn, ft in six.iteritems(self.model._fields):
            if fn not in self.clean_map:
                continue
            if isinstance(ft, BooleanField):
                self.clean_map[fn] = self.clean_bool
            elif isinstance(ft, (PlainReferenceField, ReferenceField)):
                if fn in self.mapped_fields:
                    self.clean_map[fn] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[fn]),
                        ft.document_type)
            elif isinstance(ft, ForeignKeyField):
                if fn in self.mapped_fields:
                    self.clean_map[fn] = functools.partial(
                        self.clean_int_reference,
                        self.chain.get_mappings(self.mapped_fields[fn]),
                        ft.document_type)
            elif fn in self.mapped_fields:
                self.clean_map[fn] = functools.partial(
                    self.clean_map_str,
                    self.chain.get_mappings(self.mapped_fields[fn]))

    def update_model_clean_map(self):
        from django.db.models import BooleanField, ForeignKey
        from noc.core.model.fields import DocumentReferenceField

        for f in self.model._meta.fields:
            if f.name not in self.clean_map:
                continue
            if isinstance(f, BooleanField):
                self.clean_map[f.name] = self.clean_bool
            elif isinstance(f, DocumentReferenceField):
                if f.name in self.mapped_fields:
                    self.clean_map[f.name] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[f.name]),
                        f.document)
            elif isinstance(f, ForeignKey):
                if f.name in self.mapped_fields:
                    self.clean_map[f.name] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[f.name]),
                        f.rel.to)
            elif f.name in self.mapped_fields:
                self.clean_map[f.name] = functools.partial(
                    self.clean_map_str,
                    self.chain.get_mappings(self.mapped_fields[f.name]))

    def check(self, chain):
        self.logger.info("Checking")
        # Get constraints
        if self.is_document:
            # Document
            required_fields = [
                f.name for f in self.model._fields.itervalues()
                if f.required or f.unique
            ]
            unique_fields = [
                f.name for f in self.model._fields.itervalues() if f.unique
            ]
        else:
            # Model
            required_fields = [
                f.name for f in self.model._meta.fields if not f.blank
            ]
            unique_fields = [
                f.name for f in self.model._meta.fields
                if f.unique and f.name != self.model._meta.pk.name
            ]
        if not required_fields and not unique_fields:
            self.logger.info("Nothing to check, skipping")
            return 0
        # Prepare data
        ns = self.get_new_state()
        if not ns:
            self.logger.info("No new state, skipping")
            return 0
        new_state = csv.reader(ns)
        r_index = set(
            self.fields.index(f) for f in required_fields if f in self.fields)
        u_index = set(
            self.fields.index(f) for f in unique_fields
            if f not in self.ignore_unique)
        m_index = set(self.fields.index(f) for f in self.mapped_fields)
        uv = set()
        m_data = {}  # field_number -> set of mapped ids
        # Load mapped ids
        for f in self.mapped_fields:
            line = chain.get_loader(self.mapped_fields[f])
            ls = line.get_new_state()
            if not ls:
                ls = line.get_current_state()
            ms = csv.reader(ls)
            m_data[self.fields.index(f)] = set(row[0] for row in ms)
        # Process data
        n_errors = 0
        for row in new_state:
            lr = len(row)
            # Check required fields
            for i in r_index:
                if not row[i]:
                    self.logger.error(
                        "ERROR: Required field #%d(%s) is missed in row: %s",
                        i, self.fields[i], ",".join(row))
                    n_errors += 1
                    continue
            # Check unique fields
            for i in u_index:
                v = row[i]
                if (i, v) in uv:
                    self.logger.error(
                        "ERROR: Field #%d(%s) value is not unique: %s", i,
                        self.fields[i], ",".join(row))
                    n_errors += 1
                else:
                    uv.add((i, v))
            # Check mapped fields
            for i in m_index:
                if i >= lr:
                    continue
                v = row[i]
                if v and v not in m_data[i]:
                    self.logger.error(
                        "ERROR: Field #%d(%s) == '%s' refers to non-existent record: %s",
                        i, self.fields[i], row[i], ",".join(row))
                    n_errors += 1
        if n_errors:
            self.logger.info("%d errors found", n_errors)
        else:
            self.logger.info("No errors found")
        return n_errors

    def check_diff(self):
        def dump(cmd, row):
            print("%s %s" % (cmd, ",".join(row)))

        print("--- %s.%s" % (self.chain.system.name, self.name))
        ns = self.get_new_state()
        if not ns:
            return
        current_state = csv.reader(self.get_current_state())
        new_state = csv.reader(ns)
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                dump("+", n)
            elif o and n is None:
                dump("-", o)
            else:
                dump("/", o)
                dump("\\", n)

    def check_diff_summary(self):
        i, u, d = 0, 0, 0
        ns = self.get_new_state()
        if not ns:
            return i, u, d
        current_state = csv.reader(self.get_current_state())
        new_state = csv.reader(ns)
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                i += 1
            elif o and n is None:
                d += 1
            else:
                u += 1
        return i, u, d
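
A standalone sketch of the sorted-merge diff that BaseLoader.diff() applies
to the old and new CSV states: both inputs must be sorted by the first
column, and every yielded pair is (old, None) for a removal, (None, new)
for an addition, or (old, new) for a change. The sample rows are made up
for illustration.

import csv

def diff(old, new):
    o = next(old, None)
    n = next(new, None)
    while o or n:
        if not o:                   # old stream exhausted -> addition
            yield None, n
            n = next(new, None)
        elif not n:                 # new stream exhausted -> removal
            yield o, None
            o = next(old, None)
        elif n[0] == o[0]:          # same key -> changed if rows differ
            if n != o:
                yield o, n
            o, n = next(old, None), next(new, None)
        elif n[0] < o[0]:           # key only in new -> addition
            yield None, n
            n = next(new, None)
        else:                       # key only in old -> removal
            yield o, None
            o = next(old, None)

old_state = csv.reader(["1,alpha", "2,beta", "4,delta"])
new_state = csv.reader(["1,alpha", "3,gamma", "4,DELTA"])
for pair in diff(old_state, new_state):
    print(pair)
# (['2', 'beta'], None)
# (None, ['3', 'gamma'])
# (['4', 'delta'], ['4', 'DELTA'])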
Code example #3
class BaseScript(six.with_metaclass(BaseScriptMetaclass, object)):
    """
    Service Activation script base class
    """

    # Script name in form of <vendor>.<system>.<name>
    name = None
    # Default script timeout
    TIMEOUT = config.script.timeout
    # Default session timeout
    SESSION_IDLE_TIMEOUT = config.script.session_idle_timeout
    # Default access preference
    DEFAULT_ACCESS_PREFERENCE = "SC"
    # Enable call cache
    # If True, script result will be cached and reused
    # during lifetime of parent script
    cache = False
    # Implemented interface
    interface = None
    # Scripts required by generic scripts.
    # For common scripts - an empty list
    # For generics - a list of (script_name, interface) pairs
    requires = []
    #
    base_logger = logging.getLogger(name or "script")
    #
    _x_seq = itertools.count()
    # Sessions
    session_lock = Lock()
    session_cli = {}
    session_mml = {}
    session_rtsp = {}
    # In session mode when active CLI session exists
    # * True -- reuse session
    # * False -- close session and run new without session context
    reuse_cli_session = True
    # In session mode:
    # Should we keep CLI session for reuse by next script
    # * True - keep CLI session for next script
    # * False - close CLI session
    keep_cli_session = True
    # Script-level matchers.
    # Override profile one
    matchers = {}

    # Error classes shortcuts
    ScriptError = ScriptError
    CLISyntaxError = CLISyntaxError
    CLIOperationError = CLIOperationError
    NotSupportedError = NotSupportedError
    UnexpectedResultError = UnexpectedResultError

    hexbin = {
        "0": "0000",
        "1": "0001",
        "2": "0010",
        "3": "0011",
        "4": "0100",
        "5": "0101",
        "6": "0110",
        "7": "0111",
        "8": "1000",
        "9": "1001",
        "a": "1010",
        "b": "1011",
        "c": "1100",
        "d": "1101",
        "e": "1110",
        "f": "1111",
    }
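    # The hexbin table expands a hex string into its bit representation
    # without format helpers, e.g.
    # "".join(BaseScript.hexbin[c] for c in "a5") -> "10100101"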

    cli_protocols = {
        "telnet": "noc.core.script.cli.telnet.TelnetCLI",
        "ssh": "noc.core.script.cli.ssh.SSHCLI",
        "beef": "noc.core.script.cli.beef.BeefCLI",
    }

    mml_protocols = {"telnet": "noc.core.script.mml.telnet.TelnetMML"}

    rtsp_protocols = {"tcp": "noc.core.script.rtsp.base.RTSPBase"}
    # Override access preferences for script
    # S - always try SNMP first
    # C - always try CLI first
    # None - use default preferences
    always_prefer = None

    def __init__(
        self,
        service,
        credentials,
        args=None,
        capabilities=None,
        version=None,
        parent=None,
        timeout=None,
        name=None,
        session=None,
        session_idle_timeout=None,
    ):
        self.service = service
        self.tos = config.activator.tos
        self.pool = config.pool
        self.parent = parent
        self._motd = None
        name = name or self.name
        self.logger = PrefixLoggerAdapter(
            self.base_logger, "%s] [%s" % (self.name, credentials.get("address", "-"))
        )
        if self.parent:
            self.profile = self.parent.profile
        else:
            self.profile = profile_loader.get_profile(".".join(name.split(".")[:2]))()
        self.credentials = credentials or {}
        self.version = version or {}
        self.capabilities = capabilities or {}
        self.timeout = timeout or self.get_timeout()
        self.start_time = None
        self._interface = self.interface()
        self.args = self.clean_input(args) if args else {}
        self.cli_stream = None
        self.mml_stream = None
        self.rtsp_stream = None
        if self.parent:
            self.snmp = self.root.snmp
        elif self.is_beefed:
            self.snmp = BeefSNMP(self)
        else:
            self.snmp = SNMP(self)
        if self.parent:
            self.http = self.root.http
        else:
            self.http = HTTP(self)
        self.to_disable_pager = not self.parent and self.profile.command_disable_pager
        self.scripts = ScriptsHub(self)
        # Store session id
        self.session = session
        self.session_idle_timeout = session_idle_timeout or self.SESSION_IDLE_TIMEOUT
        # Cache CLI and SNMP calls, if set
        self.is_cached = False
        # Suitable only when self.parent is None.
        # Cached results for scripts marked with "cache"
        self.call_cache = {}
        # Suitable only when self.parent is None
        # Cached results of self.cli calls
        self.cli_cache = {}
        #
        self.http_cache = {}
        self.partial_result = None
        # Tracking
        self.to_track = False
        self.cli_tracked_data = {}  # command -> [packets]
        self.cli_tracked_command = None
        # state -> [..]
        self.cli_fsm_tracked_data = {}
        #
        if not parent and version and not name.endswith(".get_version"):
            self.logger.debug("Filling get_version cache with %s", version)
            s = name.split(".")
            self.set_cache("%s.%s.get_version" % (s[0], s[1]), {}, version)
        # Fill matchers
        if not self.name.endswith(".get_version"):
            self.apply_matchers()
        #
        if self.profile.setup_script:
            self.profile.setup_script(self)

    def __call__(self, *args, **kwargs):
        self.args = kwargs
        return self.run()

    def apply_matchers(self):
        """
        Process matchers and apply is_XXX properties
        :return:
        """

        def get_matchers(c, matchers):
            return dict((m, match(c, matchers[m])) for m in matchers)

        # Match context
        # @todo: Add capabilities
        ctx = self.version or {}
        if self.capabilities:
            ctx["caps"] = self.capabilities
        # Calculate matches
        v = get_matchers(ctx, self.profile.matchers)
        v.update(get_matchers(ctx, self.matchers))
        #
        for k in v:
            self.logger.debug("%s = %s", k, v[k])
            setattr(self, k, v[k])

    def clean_input(self, args):
        """
        Cleanup input parameters against interface
        """
        return self._interface.script_clean_input(self.profile, **args)

    def clean_output(self, result):
        """
        Clean script result against interface
        """
        return self._interface.script_clean_result(self.profile, result)

    def run(self):
        """
        Run script
        """
        with Span(server="activator", service=self.name, in_label=self.credentials.get("address")):
            self.start_time = perf_counter()
            self.logger.debug("Running. Input arguments: %s, timeout %s", self.args, self.timeout)
            # Use cached result when available
            cache_hit = False
            if self.cache and self.parent:
                try:
                    result = self.get_cache(self.name, self.args)
                    self.logger.info("Using cached result")
                    cache_hit = True
                except KeyError:
                    pass
            # Execute script
            if not cache_hit:
                try:
                    result = self.execute(**self.args)
                    if self.cache and self.parent and result:
                        self.logger.info("Caching result")
                        self.set_cache(self.name, self.args, result)
                finally:
                    if not self.parent:
                        # Close SNMP socket when necessary
                        self.close_snmp()
                        # Close CLI socket when necessary
                        self.close_cli_stream()
                        # Close MML socket when necessary
                        self.close_mml_stream()
                        # Close RTSP socket when necessary
                        self.close_rtsp_stream()
                        # Close HTTP Client
                        self.http.close()
            # Clean result
            result = self.clean_output(result)
            self.logger.debug("Result: %s", result)
            runtime = perf_counter() - self.start_time
            self.logger.info("Complete (%.2fms)", runtime * 1000)
        return result

    @classmethod
    def compile_match_filter(cls, *args, **kwargs):
        """
        Compile arguments into version check function
        Returns callable accepting self and version hash arguments
        """
        c = [lambda self, x, g=f: g(x) for f in args]
        for k, v in six.iteritems(kwargs):
            # Split to field name and lookup operator
            if "__" in k:
                f, o = k.split("__")
            else:
                f = k
                o = "exact"
            # Check field name
            if f not in ("vendor", "platform", "version", "image"):
                raise Exception("Invalid field '%s'" % f)
            # Compile lookup functions
            if o == "exact":
                c += [lambda self, x, f=f, v=v: x[f] == v]
            elif o == "iexact":
                c += [lambda self, x, f=f, v=v: x[f].lower() == v.lower()]
            elif o == "startswith":
                c += [lambda self, x, f=f, v=v: x[f].startswith(v)]
            elif o == "istartswith":
                c += [lambda self, x, f=f, v=v: x[f].lower().startswith(v.lower())]
            elif o == "endswith":
                c += [lambda self, x, f=f, v=v: x[f].endswith(v)]
            elif o == "iendswith":
                c += [lambda self, x, f=f, v=v: x[f].lower().endswith(v.lower())]
            elif o == "contains":
                c += [lambda self, x, f=f, v=v: v in x[f]]
            elif o == "icontains":
                c += [lambda self, x, f=f, v=v: v.lower() in x[f].lower()]
            elif o == "in":
                c += [lambda self, x, f=f, v=v: x[f] in v]
            elif o == "regex":
                c += [lambda self, x, f=f, v=re.compile(v): v.search(x[f]) is not None]
            elif o == "iregex":
                c += [
                    lambda self, x, f=f, v=re.compile(v, re.IGNORECASE): v.search(x[f]) is not None
                ]
            elif o == "isempty":  # Empty string or null
                c += [lambda self, x, f=f, v=v: not x[f] if v else x[f]]
            elif f == "version":
                if o == "lt":  # <
                    c += [lambda self, x, v=v: self.profile.cmp_version(x["version"], v) < 0]
                elif o == "lte":  # <=
                    c += [lambda self, x, v=v: self.profile.cmp_version(x["version"], v) <= 0]
                elif o == "gt":  # >
                    c += [lambda self, x, v=v: self.profile.cmp_version(x["version"], v) > 0]
                elif o == "gte":  # >=
                    c += [lambda self, x, v=v: self.profile.cmp_version(x["version"], v) >= 0]
                else:
                    raise Exception("Invalid lookup operation: %s" % o)
            else:
                raise Exception("Invalid lookup operation: %s" % o)
        # Combine expressions into single lambda
        return reduce(
            lambda x, y: lambda self, v, x=x, y=y: (x(self, v) and y(self, v)),
            c,
            lambda self, x: True,
        )
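
    # Illustrative sketch (not part of the original source): the compiled
    # filter is a single predicate over a version hash. Something like:
    #
    #     check = BaseScript.compile_match_filter(platform__startswith="S2",
    #                                             version__gte="2.0")
    #     check(script, {"vendor": "X", "platform": "S200",
    #                    "version": "2.1", "image": ""})
    #
    # returns True when both lookups hold; version__gte delegates to
    # self.profile.cmp_version, so the result depends on the profile.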

    @classmethod
    def match(cls, *args, **kwargs):
        """
        execute method decorator
        """

        def wrap(f):
            # Append to the execute chain
            if hasattr(f, "_match"):
                old_filter = f._match
                f._match = lambda self, v, old_filter=old_filter, new_filter=new_filter: new_filter(
                    self, v
                ) or old_filter(self, v)
            else:
                f._match = new_filter
            f._seq = next(cls._x_seq)
            return f

        # Compile check function
        new_filter = cls.compile_match_filter(*args, **kwargs)
        # Return decorated function
        return wrap
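
    # Deprecated-API sketch (not part of the original source):
    # @BaseScript.match() attaches the compiled filter to a handler, which
    # execute() then selects by version, e.g.:
    #
    #     class Script(BaseScript):            # hypothetical script
    #         name = "Vendor.OS.get_something"
    #
    #         @BaseScript.match(platform__startswith="C29")
    #         def execute_c29(self, **kwargs):
    #             ...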

    def match_version(self, *args, **kwargs):
        """
        inline version for BaseScript.match
        """
        if not self.version:
            self.version = self.scripts.get_version()
        return self.compile_match_filter(*args, **kwargs)(self, self.version)

    def execute(self, **kwargs):
        """
        Default script behavior:
        Pass through _execute_chain and call appropriate handler
        """
        if self._execute_chain and not self.name.endswith(".get_version"):
            # Deprecated @match chain
            self.logger.info(
                "WARNING: Using deprecated @BaseScript.match() decorator. "
                "Consider porting to the new matcher API"
            )
            # Get version information
            if not self.version:
                self.version = self.scripts.get_version()
            # Find and execute proper handler
            for f in self._execute_chain:
                if f._match(self, self.version):
                    return f(self, **kwargs)
            # No matching handler found
            raise self.NotSupportedError()
        else:
            # New SNMP/CLI API
            return self.call_method(
                cli_handler=self.execute_cli, snmp_handler=self.execute_snmp, **kwargs
            )

    def call_method(self, cli_handler=None, snmp_handler=None, fallback_handler=None, **kwargs):
        """
        Call function depending on access_preference
        :param cli_handler: String or callable to call on CLI access method
        :param snmp_handler: String or callable to call on SNMP access method
        :param fallback_handler: String or callable to call if no access method matched
        :param kwargs:
        :return:
        """
        # Select proper handler
        access_preference = self.get_access_preference() + "*"
        for m in access_preference:
            # Select proper handler
            if m == "C":
                handler = cli_handler
            elif m == "S":
                if self.has_snmp():
                    handler = snmp_handler
                else:
                    self.logger.debug("SNMP is not enabled. Passing to next method")
                    continue
            elif m == "*":
                handler = fallback_handler
            else:
                raise self.NotSupportedError("Invalid access method '%s'" % m)
            # Resolve handler when necessary
            if isinstance(handler, six.string_types):
                handler = getattr(self, handler, None)
            if handler is None:
                self.logger.debug("No '%s' handler. Passing to next method" % m)
                continue
            # Call handler
            try:
                r = handler(**kwargs)
                if isinstance(r, PartialResult):
                    if self.partial_result:
                        self.partial_result.update(r.result)
                    else:
                        self.partial_result = r.result
                    self.logger.debug(
                        "Partial result: %r. Passing to next method", self.partial_result
                    )
                else:
                    return r
            except self.snmp.TimeOutError:
                self.logger.info("SNMP timeout. Passing to next method")
                if access_preference == "S*":
                    self.logger.info("Last S method break by timeout.")
                    raise self.snmp.TimeOutError
            except NotImplementedError:
                self.logger.debug(
                    "Access method '%s' is not implemented. Passing to next method", m
                )
        raise self.NotSupportedError(
            "Access preference '%s' is not supported" % access_preference[:-1]
        )
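
    # Illustrative usage (not part of the original source): a script routes
    # one logical operation through the access preference order, e.g.:
    #
    #     def execute(self, **kwargs):
    #         return self.call_method(
    #             cli_handler="execute_cli",
    #             snmp_handler="execute_snmp",
    #             **kwargs
    #         )
    #
    # With preference "SC" the SNMP handler is tried first; an SNMP timeout
    # or NotImplementedError falls through to the CLI handler.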

    def execute_cli(self, **kwargs):
        """
        Process script using CLI
        :param kwargs:
        :return:
        """
        raise NotImplementedError("execute_cli() is not implemented")

    def execute_snmp(self, **kwargs):
        """
        Process script using SNMP
        :param kwargs:
        :return:
        """
        raise NotImplementedError("execute_snmp() is not implemented")

    def cleaned_config(self, config):
        """
        Clean up config from all unnecessary trash
        """
        return self.profile.cleaned_config(config)

    def strip_first_lines(self, text, lines=1):
        """
        Strip first *lines*
        """
        t = text.split("\n")
        if len(t) <= lines:
            return ""
        else:
            return "\n".join(t[lines:])

    def expand_rangelist(self, s):
        """
        Expand expressions like "1,2,5-7" to [1, 2, 5, 6, 7]
        """
        result = {}
        for x in s.split(","):
            x = x.strip()
            if x == "":
                continue
            if "-" in x:
                left, right = [int(y) for y in x.split("-")]
                if left > right:
                    x = right
                    right = left
                    left = x
                for i in range(left, right + 1):
                    result[i] = None
            else:
                result[int(x)] = None
        return sorted(result.keys())

    rx_detect_sep = re.compile(r"^(.*?)\d+$")

    def expand_interface_range(self, s):
        """
        Convert interface range expression to a list
        of interfaces
        "Gi 1/1-3,Gi 1/7" -> ["Gi 1/1", "Gi 1/2", "Gi 1/3", "Gi 1/7"]
        "1:1-3" -> ["1:1", "1:2", "1:3"]
        "1:1-1:3" -> ["1:1", "1:2", "1:3"]
        :param s: Comma-separated list
        :return:
        """
        r = set()
        for x in s.split(","):
            x = x.strip()
            if not x:
                continue
            if "-" in x:
                # Expand range
                f, t = [y.strip() for y in x.split("-")]
                # Detect common prefix
                match = self.rx_detect_sep.match(f)
                if not match:
                    raise ValueError(x)
                prefix = match.group(1)
                # Detect range boundaries
                start = int(f[len(prefix) :])
                if is_int(t):
                    stop = int(t)  # Just integer
                else:
                    if not t.startswith(prefix):
                        raise ValueError(x)
                    stop = int(t[len(prefix) :])  # Prefixed
                if start > stop:
                    raise ValueError(x)
                for i in range(start, stop + 1):
                    r.add(prefix + str(i))
            else:
                r.add(x)
        return sorted(r)

    def macs_to_ranges(self, macs):
        """
        Converts a list of macs to ranges
        :param macs: Iterable yielding mac addresses
        :returns: [(from, to), ..]
        """
        r = []
        for m in sorted(MAC(x) for x in macs):
            if r:
                if r[-1][1].shift(1) == m:
                    # Expand last range
                    r[-1][1] = m
                else:
                    r += [[m, m]]
            else:
                r += [[m, m]]
        return [(str(x[0]), str(x[1])) for x in r]
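
    # Illustrative example (not part of the original source): consecutive
    # addresses collapse into a single range:
    #
    #     self.macs_to_ranges(["00:11:22:33:44:55", "00:11:22:33:44:56"])
    #     # -> [("00:11:22:33:44:55", "00:11:22:33:44:56")]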

    def hexstring_to_mac(self, s):
        """Convert a 6-octet string to MAC address"""
        return ":".join(["%02X" % ord(x) for x in s])

    @property
    def root(self):
        """Get root script"""
        if self.parent:
            return self.parent.root
        else:
            return self

    def get_cache(self, key1, key2):
        """Get cached result or raise KeyError"""
        s = self.root
        return s.call_cache[repr(key1)][repr(key2)]

    def set_cache(self, key1, key2, value):
        """Set cached result"""
        key1 = repr(key1)
        key2 = repr(key2)
        s = self.root
        if key1 not in s.call_cache:
            s.call_cache[key1] = {}
        s.call_cache[key1][key2] = value

    def configure(self):
        """Returns configuration context"""
        return ConfigurationContextManager(self)

    def cached(self):
        """
        Return cached context manager. All nested CLI and SNMP GET/GETNEXT
        calls will be cached.

        Usage:

        with self.cached():
            self.cli(".....)
            self.scripts.script()
        """
        return CacheContextManager(self)

    def enter_config(self):
        """Enter configuration mote"""
        if self.profile.command_enter_config:
            self.cli(self.profile.command_enter_config)

    def leave_config(self):
        """Leave configuration mode"""
        if self.profile.command_leave_config:
            self.cli(self.profile.command_leave_config)
            self.cli("")  # Guardian empty command to wait until configuration is finally written

    def save_config(self, immediately=False):
        """Save current config"""
        if immediately:
            if self.profile.command_save_config:
                self.cli(self.profile.command_save_config)
        else:
            self.schedule_to_save()

    def schedule_to_save(self):
        self.need_to_save = True
        if self.parent:
            self.parent.schedule_to_save()

    def set_motd(self, motd):
        self._motd = motd

    @property
    def motd(self):
        """
        Return message of the day
        """
        if self._motd:
            return self._motd
        else:
            return self.get_cli_stream().get_motd()

    def re_search(self, rx, s, flags=0):
        """
        Match s against regular expression rx using re.search
        Raise UnexpectedResultError if regular expression is not matched.
        Returns match object.
        rx can be string or compiled regular expression
        """
        if isinstance(rx, six.string_types):
            rx = re.compile(rx, flags)
        match = rx.search(s)
        if match is None:
            raise UnexpectedResultError()
        return match

    def re_match(self, rx, s, flags=0):
        """
        Match s against regular expression rx using re.match
        Raise UnexpectedResultError if regular expression is not matched.
        Returns match object.
        rx can be string or compiled regular expression
        """
        if isinstance(rx, six.string_types):
            rx = re.compile(rx, flags)
        match = rx.match(s)
        if match is None:
            raise UnexpectedResultError()
        return match

    _match_lines_cache = {}

    @classmethod
    def match_lines(cls, rx, s):
        k = id(rx)
        if k not in cls._match_lines_cache:
            _rx = [re.compile(line, re.IGNORECASE) for line in rx]
            cls._match_lines_cache[k] = _rx
        else:
            _rx = cls._match_lines_cache[k]
        ctx = {}
        idx = 0
        r = _rx[0]
        for line in s.splitlines():
            line = line.strip()
            match = r.search(line)
            if match:
                ctx.update(match.groupdict())
                idx += 1
                if idx == len(_rx):
                    return ctx
                r = _rx[idx]
        return None
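
    # Illustrative usage (not part of the original source): patterns are
    # applied in order, accumulating named groups across lines:
    #
    #     rx = [r"Model:\s+(?P<platform>\S+)", r"Version:\s+(?P<version>\S+)"]
    #     BaseScript.match_lines(rx, "Model: X100\nVersion: 1.2")
    #     # -> {"platform": "X100", "version": "1.2"}
    #
    # Note the compiled patterns are cached by id(rx), so pass the same list
    # object on repeated calls.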

    def find_re(self, iter, s):
        """
        Find first matching regular expression
        or raise Unexpected result error
        """
        for r in iter:
            if r.search(s):
                return r
        raise UnexpectedResultError()

    def hex_to_bin(self, s):
        """
        Convert a string to its binary representation:
        each input byte is rendered as eight "0"/"1" characters
        :param s: Input string
        :return: Boolean string
        :rtype: str
        """
        return "".join(self.hexbin[c] for c in "".join("%02x" % ord(d) for d in s))

    def push_prompt_pattern(self, pattern):
        self.get_cli_stream().push_prompt_pattern(pattern)

    def pop_prompt_pattern(self):
        self.get_cli_stream().pop_prompt_pattern()

    def has_oid(self, oid):
        """
        Check whether the object responds to oid
        """
        try:
            return bool(self.snmp.get(oid))
        except self.snmp.TimeOutError:
            return False

    def get_timeout(self):
        return self.TIMEOUT

    def cli(
        self,
        cmd,
        command_submit=None,
        bulk_lines=None,
        list_re=None,
        cached=False,
        file=None,
        ignore_errors=False,
        allow_empty_response=True,
        nowait=False,
        obj_parser=None,
        cmd_next=None,
        cmd_stop=None,
    ):
        """
        Execute CLI command and return result. Initiate cli session
        when necessary
        :param cmd: CLI command to execute
        :param command_submit:
        :param bulk_lines:
        :param list_re:
        :param cached:
        :param file:
        :param ignore_errors:
        :param allow_empty_response: Allow empty output. If False - ignore prompt and wait output
        :type allow_empty_response: bool
        :param nowait:

        If list_re is None, return a string.
        If list_re is a compiled regular expression object, return a list of
        dicts (group name -> value), one dict per matched line.
        """

        def format_result(result):
            if list_re:
                x = []
                for l in result.splitlines():
                    match = list_re.match(l.strip())
                    if match:
                        x += [match.groupdict()]
                return x
            else:
                return result

        if file:
            with open(file) as f:
                return format_result(f.read())
        if cached:
            r = self.root.cli_cache.get(cmd)
            if r is not None:
                self.logger.debug("Use cached result")
                return format_result(r)
        command_submit = command_submit or self.profile.command_submit
        stream = self.get_cli_stream()
        if self.to_track:
            self.cli_tracked_command = cmd
        r = stream.execute(
            cmd + command_submit,
            obj_parser=obj_parser,
            cmd_next=cmd_next,
            cmd_stop=cmd_stop,
            ignore_errors=ignore_errors,
            allow_empty_response=allow_empty_response,
        )
        if isinstance(r, six.string_types):
            # Check for syntax errors
            if not ignore_errors:
                # Then check for operation error
                if (
                    self.profile.rx_pattern_operation_error
                    and self.profile.rx_pattern_operation_error.search(r)
                ):
                    raise self.CLIOperationError(r)
            # Echo cancellation
            if r[:4096].lstrip().startswith(cmd):
                r = r.lstrip()
                if r.startswith(cmd + "\n"):
                    # Remove first line
                    r = self.strip_first_lines(r.lstrip())
                else:
                    # Some switches, like ProCurve do not send \n after the echo
                    r = r[len(cmd) :]
            # Store cli cache when necessary
            if cached:
                self.root.cli_cache[cmd] = r
        return format_result(r)
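
    # Illustrative usage (not part of the original source): with list_re,
    # each matching output line becomes a dict of named groups, e.g.:
    #
    #     rx_vlan = re.compile(r"^(?P<vlan_id>\d+)\s+(?P<name>\S+)")
    #     vlans = self.cli("show vlan", list_re=rx_vlan)
    #     # -> [{"vlan_id": "1", "name": "default"}, ...]
    #
    # "show vlan" is a placeholder; the actual command is profile-dependent.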

    def get_cli_stream(self):
        if self.parent:
            return self.root.get_cli_stream()
        if not self.cli_stream and self.session:
            # Try to get cached session's CLI
            with self.session_lock:
                self.cli_stream = self.session_cli.get(self.session)
                if self.cli_stream:
                    if self.cli_stream.is_closed:
                        # Stream closed by external reason,
                        # mark as invalid and start new one
                        self.cli_stream = None
                    # Remove stream from pool to prevent cli session hijacking
                    del self.session_cli[self.session]
            if self.cli_stream:
                if self.to_reuse_cli_session():
                    self.logger.debug("Using cached session's CLI")
                    self.cli_stream.set_script(self)
                else:
                    self.logger.debug("Script cannot reuse existing CLI session, starting new one")
                    self.close_cli_stream()
        if not self.cli_stream:
            protocol = self.credentials.get("cli_protocol", "telnet")
            self.logger.debug("Open %s CLI", protocol)
            self.cli_stream = get_handler(self.cli_protocols[protocol])(self, tos=self.tos)
            # Store to the sessions
            if self.session:
                with self.session_lock:
                    self.session_cli[self.session] = self.cli_stream
            self.cli_stream.setup_session()
            # Disable pager when necessary
            # @todo: Move to CLI
            if self.to_disable_pager:
                self.logger.debug("Disable paging")
                self.to_disable_pager = False
                if isinstance(self.profile.command_disable_pager, six.string_types):
                    self.cli(self.profile.command_disable_pager, ignore_errors=True)
                elif isinstance(self.profile.command_disable_pager, list):
                    for cmd in self.profile.command_disable_pager:
                        self.cli(cmd, ignore_errors=True)
                else:
                    raise UnexpectedResultError
        return self.cli_stream

    def close_cli_stream(self):
        if self.parent:
            return
        if self.cli_stream:
            if self.session and self.to_keep_cli_session():
                # Return cli stream to pool
                self.session_cli[self.session] = self.cli_stream
                # Schedule stream closing
                self.cli_stream.deferred_close(self.session_idle_timeout)
            else:
                self.cli_stream.shutdown_session()
                self.cli_stream.close()
            self.cli_stream = None

    def close_snmp(self):
        if self.parent:
            return
        if self.snmp:
            self.snmp.close()
            self.snmp = None

    def mml(self, cmd, **kwargs):
        """
        Execute MML command and return result. Initiate MML session when necessary
        :param cmd:
        :param kwargs:
        :return:
        """
        stream = self.get_mml_stream()
        r = stream.execute(cmd, **kwargs)
        return r

    def get_mml_stream(self):
        if self.parent:
            return self.root.get_mml_stream()
        if not self.mml_stream and self.session:
            # Try to get cached session's CLI
            with self.session_lock:
                self.mml_stream = self.session_mml.get(self.session)
                if self.mml_stream and self.mml_stream.is_closed:
                    self.mml_stream = None
                    del self.session_mml[self.session]
            if self.mml_stream:
                if self.to_reuse_cli_session():
                    self.logger.debug("Using cached session's MML")
                    self.mml_stream.set_script(self)
                else:
                    self.logger.debug("Script cannot reuse existing MML session, starting new one")
                    self.close_mml_stream()
        if not self.mml_stream:
            protocol = self.credentials.get("cli_protocol", "telnet")
            self.logger.debug("Open %s MML", protocol)
            self.mml_stream = get_handler(self.mml_protocols[protocol])(self, tos=self.tos)
            # Store to the sessions
            if self.session:
                with self.session_lock:
                    self.session_mml[self.session] = self.mml_stream
        return self.mml_stream

    def close_mml_stream(self):
        if self.parent:
            return
        if self.mml_stream:
            if self.session and self.to_keep_cli_session():
                self.mml_stream.deferred_close(self.session_idle_timeout)
            else:
                self.mml_stream.close()
            self.mml_stream = None

    def rtsp(self, method, path, **kwargs):
        """
        Execute RTSP command and return result. Initiate RTSP session when necessary
        :param method:
        :param path:
        :param kwargs:
        :return:
        """
        stream = self.get_rtsp_stream()
        r = stream.execute(path, method, **kwargs)
        return r

    def get_rtsp_stream(self):
        if self.parent:
            return self.root.get_rtsp_stream()
        if not self.rtsp_stream and self.session:
            # Try to get cached session's CLI
            with self.session_lock:
                self.rtsp_stream = self.session_rtsp.get(self.session)
                if self.rtsp_stream and self.rtsp_stream.is_closed:
                    self.rtsp_stream = None
                    del self.session_rtsp[self.session]
            if self.rtsp_stream:
                if self.to_reuse_cli_session():
                    self.logger.debug("Using cached session's RTSP")
                    self.rtsp_stream.set_script(self)
                else:
                    self.logger.debug("Script cannot reuse existing RTSP session, starting new one")
                    self.close_rtsp_stream()
        if not self.rtsp_stream:
            protocol = "tcp"
            self.logger.debug("Open %s RTSP", protocol)
            self.rtsp_stream = get_handler(self.rtsp_protocols[protocol])(self, tos=self.tos)
            # Store to the sessions
            if self.session:
                with self.session_lock:
                    self.session_rtsp[self.session] = self.rtsp_stream
        return self.rtsp_stream

    def close_rtsp_stream(self):
        if self.parent:
            return
        if self.rtsp_stream:
            if self.session and self.to_keep_cli_session():
                self.rtsp_stream.deferred_close(self.session_idle_timeout)
            else:
                self.rtsp_stream.close()
            self.rtsp_stream = None

    def close_current_session(self):
        if self.session:
            self.close_session(self.session)

    @classmethod
    def close_session(cls, session_id):
        """
        Explicit session closing
        :return:
        """
        with cls.session_lock:
            cli_stream = cls.session_cli.get(session_id)
            if cli_stream:
                del cls.session_cli[session_id]
            mml_stream = cls.session_mml.get(session_id)
            if mml_stream:
                del cls.session_mml[session_id]
            rtsp_stream = cls.session_rtsp.get(session_id)
            if rtsp_stream:
                del cls.session_rtsp[session_id]
        if cli_stream and not cli_stream.is_closed:
            cli_stream.shutdown_session()
            cli_stream.close()
        if mml_stream and not mml_stream.is_closed:
            mml_stream.shutdown_session()
            mml_stream.close()
        if rtsp_stream and not rtsp_stream.is_closed:
            rtsp_stream.shutdown_session()
            rtsp_stream.close()

    def get_access_preference(self):
        preferred = self.get_always_preferred()
        r = self.credentials.get("access_preference", self.DEFAULT_ACCESS_PREFERENCE)
        if preferred and preferred in r:
            return preferred + "".join(x for x in r if x != preferred)
        return r
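
    # Illustrative example (not part of the original source): with
    # credentials preference "CS" and always_prefer = "S", the effective
    # order becomes "SC", i.e. SNMP is tried before CLI.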

    def get_always_preferred(self):
        """
        Return always preferred access method
        :return:
        """
        return self.always_prefer

    def has_cli_access(self):
        return "C" in self.get_access_preference()

    def has_snmp_access(self):
        return "S" in self.get_access_preference() and self.has_snmp()

    def has_cli_only_access(self):
        return self.has_cli_access() and not self.has_snmp_access()

    def has_snmp_only_access(self):
        return not self.has_cli_access() and self.has_snmp_access()

    def has_snmp(self):
        """
        Check whether equipment has SNMP enabled
        """
        if self.has_capability("SNMP", allow_zero=True):
            # If having SNMP caps - check it and credential
            return bool(self.credentials.get("snmp_ro")) and self.has_capability("SNMP")
        else:
            # if SNMP caps not exist check credential
            return bool(self.credentials.get("snmp_ro"))

    def has_snmp_v1(self):
        return self.has_capability("SNMP | v1")

    def has_snmp_v2c(self):
        return self.has_capability("SNMP | v2c")

    def has_snmp_v3(self):
        return self.has_capability("SNMP | v3")

    def has_snmp_bulk(self):
        """
        Check whether equipment supports SNMP BULK
        """
        return self.has_capability("SNMP | Bulk")

    def has_capability(self, capability, allow_zero=False):
        """
        Check whether equipment supports capability
        """
        if allow_zero:
            return self.capabilities.get(capability) is not None
        else:
            return bool(self.capabilities.get(capability))

    def ignored_exceptions(self, iterable):
        """
        Context manager to silently ignore specified exceptions
        """
        return IgnoredExceptionsContextManager(iterable)

    def iter_pairs(self, g, offset=0):
        """
        Convert iterable g to pairs,
        i.e.
        [1, 2, 3, 4] -> [(1, 2), (3, 4)]
        :param g: Iterable
        :param offset: Skip first records
        :return:
        """
        g = iter(g)
        if offset:
            for _ in range(offset):
                next(g)
        return zip(g, g)

    def to_reuse_cli_session(self):
        return self.reuse_cli_session

    def to_keep_cli_session(self):
        return self.keep_cli_session

    def start_tracking(self):
        self.logger.debug("Start tracking")
        self.to_track = True

    def stop_tracking(self):
        self.logger.debug("Stop tracking")
        self.to_track = False
        self.cli_tracked_data = {}

    def push_cli_tracking(self, r, state):
        if state == "prompt":
            if self.cli_tracked_command in self.cli_tracked_data:
                self.cli_tracked_data[self.cli_tracked_command] += [r]
            else:
                self.cli_tracked_data[self.cli_tracked_command] = [r]
        elif state in self.cli_fsm_tracked_data:
            self.cli_fsm_tracked_data[state] += [r]
        else:
            self.cli_fsm_tracked_data[state] = [r]

    def push_snmp_tracking(self, oid, tlv):
        self.logger.debug("PUSH SNMP %s: %r", oid, tlv)

    def iter_cli_tracking(self):
        """
        Yields command, packets for collected data
        :return:
        """
        for cmd in self.cli_tracked_data:
            self.logger.debug("Collecting %d tracked CLI items", len(self.cli_tracked_data[cmd]))
            yield cmd, self.cli_tracked_data[cmd]
        self.cli_tracked_data = {}

    def iter_cli_fsm_tracking(self):
        for state in self.cli_fsm_tracked_data:
            yield state, self.cli_fsm_tracked_data[state]

    def request_beef(self):
        """
        Download and return beef
        :return:
        """
        if not hasattr(self, "_beef"):
            self.logger.debug("Requesting beef")
            beef_storage_url = self.credentials.get("beef_storage_url")
            beef_path = self.credentials.get("beef_path")
            if not beef_storage_url:
                self.logger.debug("No storage URL")
                self._beef = None
                return None
            if not beef_path:
                self.logger.debug("No beef path")
                self._beef = None
                return None
            from .beef import Beef

            beef = Beef.load(beef_storage_url, beef_path)
            self._beef = beef
        return self._beef

    @property
    def is_beefed(self):
        return self.credentials.get("cli_protocol") == "beef"
Code example #4
File: base.py Project: nbashev/noc
class BaseCLI(object):
    name = "base"

    def __init__(self, script, tos: Optional[int] = None):
        self.script = script
        self.profile = script.profile
        self.logger = PrefixLoggerAdapter(self.script.logger, self.name)
        self.stream: Optional[BaseStream] = None
        self.tos = tos
        self.is_started = False
        # Current error to raise on TimeoutError
        self.timeout_exception_cls = CLIConnectionReset

    def close(self):
        self.script.close_current_session()
        self.close_stream()

    def close_stream(self):
        if self.stream:
            self.logger.debug("Closing stream")
            if self.is_started and self.profile.command_exit:
                with IOLoopContext(suppress_trace=True) as loop:
                    loop.run_until_complete(
                        self.send(smart_bytes(self.profile.command_exit)))
            self.stream.close()
            self.stream = None

    def is_closed(self):
        return not self.stream

    def set_script(self, script):
        self.script = script
        self.logger = PrefixLoggerAdapter(self.script.logger, self.name)

    def shutdown_session(self):
        raise NotImplementedError

    def set_timeout(self,
                    timeout: Optional[float] = None,
                    error: Optional[Type[Exception]] = None) -> None:
        if timeout:
            error = error or CLIConnectionReset
            self.logger.debug("Setting timeout: %ss, error=%s", timeout,
                              error.__name__)
            self.timeout_exception_cls = error
            self.stream.set_timeout(timeout)
        else:
            self.logger.debug("Resetting timeouts")
            self.timeout_exception_cls = CLIConnectionReset
            self.stream.set_timeout(None)

    def get_stream(self) -> "BaseStream":
        """
        Stream factory. Must be overridden in subclasses.
        :return:
        """
        raise NotImplementedError

    async def start_stream(self):
        self.stream = self.get_stream()
        address = self.script.credentials.get("address")
        self.logger.debug("Connecting %s", address)
        try:
            metrics["cli_connection", ("proto", self.name)] += 1
            await self.stream.connect(address,
                                      self.script.credentials.get("cli_port"))
            metrics["cli_connection_success", ("proto", self.name)] += 1
        except ConnectionRefusedError:
            self.logger.debug("Connection refused")
            metrics["cli_connection_refused", ("proto", self.name)] += 1
            raise ConnectionRefusedError
        self.logger.debug("Connected")
        await self.stream.startup()

    async def send(self, cmd: bytes) -> None:
        raise NotImplementedError
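
    # Subclassing sketch (not part of the original source): a concrete CLI
    # binds a transport by overriding get_stream(), roughly:
    #
    #     class TelnetCLI(BaseCLI):
    #         name = "telnet"
    #
    #         def get_stream(self) -> "BaseStream":
    #             return TelnetStream(self)  # hypothetical stream class
    #
    # Real transports live in noc.core.script.cli.*; their constructor
    # signatures may differ.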
Code example #5
class RTSPBase(object):
    name = "rtsp"
    iostream_class = TelnetIOStream
    default_port = 554
    BUFFER_SIZE = config.activator.buffer_size
    MATCH_TAIL = 256
    # Retries on immediate disconnect
    CONNECT_RETRIES = config.activator.connect_retries
    # Timeout after immediate disconnect
    CONNECT_TIMEOUT = config.activator.connect_timeout
    # compiled capabilities
    HAS_TCP_KEEPALIVE = hasattr(socket, "SO_KEEPALIVE")
    HAS_TCP_KEEPIDLE = hasattr(socket, "TCP_KEEPIDLE")
    HAS_TCP_KEEPINTVL = hasattr(socket, "TCP_KEEPINTVL")
    HAS_TCP_KEEPCNT = hasattr(socket, "TCP_KEEPCNT")
    HAS_TCP_NODELAY = hasattr(socket, "TCP_NODELAY")
    # Time until sending first keepalive probe
    KEEP_IDLE = 10
    # Keepalive packets interval
    KEEP_INTVL = 10
    # Terminate connection after N keepalive failures
    KEEP_CNT = 3

    def __init__(self, script, tos=None):
        self.script = script
        self.profile = script.profile
        self.logger = PrefixLoggerAdapter(self.script.logger, self.name)
        self.iostream = None
        self.ioloop = None
        self.path = None
        self.cseq = 1
        self.method = None
        self.headers = None
        self.auth = None
        self.buffer = ""
        self.is_started = False
        self.result = None
        self.error = None
        self.is_closed = False
        self.close_timeout = None
        self.current_timeout = None
        self.tos = tos
        self.rx_rtsp_end = "\r\n\r\n"

    def close(self):
        self.script.close_current_session()
        self.close_iostream()
        if self.ioloop:
            self.logger.debug("Closing IOLoop")
            self.ioloop.close(all_fds=True)
            self.ioloop = None
        self.is_closed = True

    def close_iostream(self):
        if self.iostream:
            self.iostream.close()

    def deferred_close(self, session_timeout):
        if self.is_closed or not self.iostream:
            return
        self.logger.debug("Setting close timeout to %ss", session_timeout)
        # Cannot call call_later directly due to
        # thread-safety problems
        # See tornado issue #1773
        tornado.ioloop.IOLoop.instance().add_callback(self._set_close_timeout, session_timeout)

    def _set_close_timeout(self, session_timeout):
        """
        Wrapper to deal with IOLoop.add_timeout thread safety problem
        :param session_timeout:
        :return:
        """
        self.close_timeout = tornado.ioloop.IOLoop.instance().call_later(
            session_timeout, self.close
        )

    def create_iostream(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if self.tos:
            s.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, self.tos)
        if self.HAS_TCP_NODELAY:
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if self.HAS_TCP_KEEPALIVE:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            if self.HAS_TCP_KEEPIDLE:
                s.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, self.KEEP_IDLE)
            if self.HAS_TCP_KEEPINTVL:
                s.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, self.KEEP_INTVL)
            if self.HAS_TCP_KEEPCNT:
                s.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, self.KEEP_CNT)
        return self.iostream_class(s, self)

    def set_timeout(self, timeout):
        if timeout:
            self.logger.debug("Setting timeout: %ss", timeout)
            self.current_timeout = datetime.timedelta(seconds=timeout)
        else:
            if self.current_timeout:
                self.logger.debug("Resetting timeouts")
            self.current_timeout = None

    def set_script(self, script):
        self.script = script
        if self.close_timeout:
            tornado.ioloop.IOLoop.instance().remove_timeout(self.close_timeout)
            self.close_timeout = None

    def get_uri(self, port=None):
        address = self.script.credentials.get("address")
        if not port:
            port = self.default_port
        if port:
            address += ":%s" % port
        uri = "rtsp://%s%s" % (address, self.path)
        return uri.encode("utf-8")

    @tornado.gen.coroutine
    def send(self, method=None, body=None):
        # @todo: Apply encoding
        self.error = None
        body = body or ""
        method = method or self.method
        h = {
            # "Host": str(u.netloc),
            # "Connection": "close",
            "CSeq": self.cseq,
            "User-Agent": DEFAULT_USER_AGENT,
        }
        if self.auth:
            h["Authorization"] = self.auth.build_digest_header(
                self.get_uri(), method, self.headers["WWW-Authenticate"]["Digest"]
            )
        req = b"%s %s %s\r\n%s\r\n\r\n%s" % (
            method,
            self.get_uri(),
            DEFAULT_PROTOCOL,
            "\r\n".join(b"%s: %s" % (k, h[k]) for k in h),
            body,
        )

        self.logger.debug("Send: %r", req)
        yield self.iostream.write(req)
        self.cseq += 1

    @tornado.gen.coroutine
    def submit(self):
        # Create iostream and connect, when necessary
        if not self.iostream:
            self.iostream = self.create_iostream()
            address = (self.script.credentials.get("address"), self.default_port)
            self.logger.debug("Connecting %s", address)
            try:
                yield self.iostream.connect(address)
            except tornado.iostream.StreamClosedError:
                self.logger.debug("Connection refused")
                self.error = RTSPConnectionRefused("Connection refused")
                raise tornado.gen.Return(None)
            self.logger.debug("Connected")
            yield self.iostream.startup()
        # Perform all necessary login procedures
        if not self.is_started:
            self.is_started = True
            yield self.send("OPTIONS")
            yield self.get_rtsp_response()
            if self.error and self.error.code == 401:
                self.logger.info("Authentication needed")
                self.auth = DigestAuth(
                    user=self.script.credentials.get("user"),
                    password=self.script.credentials.get("password"),
                )
                # Send command
        yield self.send()
        r = yield self.get_rtsp_response()
        raise tornado.gen.Return(r)

    @tornado.gen.coroutine
    def get_rtsp_response(self):
        result = []
        header_sep = "\r\n\r\n"
        while True:
            r = yield self.read_until_end()
            # r = r.strip()
            # Process header
            if header_sep not in r:
                self.result = ""
                self.error = RTSPBadResponse("Missing header separator")
                raise tornado.gen.Return(None)
            header, r = r.split(header_sep, 1)
            code, msg, headers = self.parse_rtsp_header(header)
            self.headers = headers
            self.logger.debug(
                "Parsed response, code: %d, message: %s, headers: %s", code, msg, headers
            )
            if code == 401:
                self.result = ""
                self.error = RTSPAuthFailed("%s (code=%s)" % (msg, code), code=int(code))
                raise tornado.gen.Return(None)
            if not 200 <= code <= 299:
                # RTSP Error
                self.result = ""
                self.error = RTSPError("%s (code=%s)" % (msg, code), code=int(code))
                raise tornado.gen.Return(None)
            result += [r]
            break
        self.result = "".join(result)
        raise tornado.gen.Return(self.result)

    @staticmethod
    def parse_rtsp_header(data):
        code, msg, headers = 200, "", {}
        for line in data.splitlines():
            if line.startswith("RTSP/1.0"):
                _, code, msg = line.split(None, 2)
            elif ":" in line:
                h, v = line.split(":", 1)
                if h in MULTIPLE_HEADER:
                    if h not in headers:
                        headers[h] = {}
                    auth, line = v.split(None, 1)
                    items = parse_http_list(line)
                    headers[h][auth] = parse_keqv_list(items)
                    continue
                headers[h] = v.strip()
        return int(code), msg, headers
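
    # Illustrative example (not part of the original source), assuming
    # neither header is listed in MULTIPLE_HEADER:
    #
    #     RTSPBase.parse_rtsp_header(
    #         "RTSP/1.0 200 OK\r\nCSeq: 2\r\nPublic: OPTIONS, DESCRIBE"
    #     )
    #     # -> (200, "OK", {"CSeq": "2", "Public": "OPTIONS, DESCRIBE"})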

    def execute(self, path, method, **kwargs):
        """
        Perform request and return result
        :param path:
        :param method:
        :param kwargs:
        :return:
        """
        if self.close_timeout:
            self.logger.debug("Removing close timeout")
            self.ioloop.remove_timeout(self.close_timeout)
            self.close_timeout = None
        self.buffer = ""
        # self.command = self.profile.get_mml_command(cmd, **kwargs)
        self.path = path
        self.method = method
        self.error = None
        if not self.ioloop:
            self.logger.debug("Creating IOLoop")
            self.ioloop = tornado.ioloop.IOLoop()
        with Span(
            server=self.script.credentials.get("address"), service=self.name, in_label=self.method
        ) as s:
            self.ioloop.run_sync(self.submit)
            if self.error:
                if s:
                    s.error_text = str(self.error)
                raise self.error
            else:
                return self.result

    @tornado.gen.coroutine
    def read_until_end(self):
        connect_retries = self.CONNECT_RETRIES
        while True:
            try:
                f = self.iostream.read_bytes(self.BUFFER_SIZE, partial=True)
                if self.current_timeout:
                    r = yield tornado.gen.with_timeout(self.current_timeout, f)
                else:
                    r = yield f
            except tornado.iostream.StreamClosedError:
                # Check if remote end closes connection just
                # after connection established
                if not self.is_started and connect_retries:
                    self.logger.info(
                        "Connection reset. %d retries left. Waiting %d seconds",
                        connect_retries,
                        self.CONNECT_TIMEOUT,
                    )
                    while connect_retries:
                        yield tornado.gen.sleep(self.CONNECT_TIMEOUT)
                        connect_retries -= 1
                        self.iostream = self.create_iostream()
                        address = (
                            self.script.credentials.get("address"),
                            self.script.credentials.get("cli_port", self.default_port),
                        )
                        self.logger.debug("Connecting %s", address)
                        try:
                            yield self.iostream.connect(address)
                            break
                        except tornado.iostream.StreamClosedError:
                            if not connect_retries:
                                raise tornado.iostream.StreamClosedError()
                    continue
                else:
                    raise tornado.iostream.StreamClosedError()
            except tornado.gen.TimeoutError:
                self.logger.info("Timeout error")
                raise tornado.gen.TimeoutError("Timeout")
            self.logger.debug("Received: %r", r)
            self.buffer += r
            # offset = max(0, len(self.buffer) - self.MATCH_TAIL)
            # match = self.rx_mml_end.search(self.buffer, offset)
            # if match:
            #     self.logger.debug("End of the block")
            #     r = self.buffer[:match.start()]
            #     self.buffer = self.buffer[match.end()]
            raise tornado.gen.Return(r)

    def shutdown_session(self):
        if self.profile.shutdown_session:
            self.logger.debug("Shutdown session")
            self.profile.shutdown_session(self.script)
Code example #6
class SNMP(object):
    name = "snmp"

    class TimeOutError(NOCError):
        default_code = ERR_SNMP_TIMEOUT
        default_msg = "SNMP Timeout"

    class FatalTimeoutError(NOCError):
        default_code = ERR_SNMP_FATAL_TIMEOUT
        default_msg = "Fatal SNMP Timeout"

    SNMPError = SNMPError

    def __init__(self, script):
        self._script = weakref.ref(script)
        self.ioloop = None
        self.result = None
        self.logger = PrefixLoggerAdapter(script.logger, self.name)
        self.timeouts_limit = 0
        self.timeouts = 0
        self.socket = None

    @property
    def script(self):
        return self._script()

    def set_timeout_limits(self, n):
        """
        Set sequential timeouts limit
        :param n:
        :return:
        """
        self.timeouts_limit = n
        self.timeouts = n

    def close(self):
        if self.socket:
            self.logger.debug("Closing UDP socket")
            self.socket.close()
            self.socket = None
        if self.ioloop:
            self.logger.debug("Closing IOLoop")
            self.ioloop.close(all_fds=True)
            self.ioloop = None

    def get_ioloop(self):
        if not self.ioloop:
            self.logger.debug("Creating IOLoop")
            self.ioloop = tornado.ioloop.IOLoop()
        return self.ioloop

    def get_socket(self):
        if not self.socket:
            self.logger.debug("Create UDP socket")
            self.socket = UDPSocket(ioloop=self.get_ioloop(),
                                    tos=self.script.tos)
        return self.socket

    def _get_snmp_version(self, version=None):
        if version is not None:
            return version
        if self.script.has_snmp_v2c():
            return SNMP_v2c
        elif self.script.has_snmp_v3():
            return SNMP_v3
        elif self.script.has_snmp_v1():
            return SNMP_v1
        return SNMP_v2c

    def get(self, oids, cached=False, version=None, raw_varbinds=False):
        """
        Perform SNMP GET request
        :param oids: string or list of oids
        :param cached: True if get results can be cached during session
        :param raw_varbinds: Return value in BER encoding
        :returns: either result scalar or dict of name -> value
        """
        @tornado.gen.coroutine
        def run():
            try:
                self.result = yield snmp_get(
                    address=self.script.credentials["address"],
                    oids=oids,
                    community=str(self.script.credentials["snmp_ro"]),
                    tos=self.script.tos,
                    ioloop=self.get_ioloop(),
                    udp_socket=self.get_socket(),
                    version=version,
                    raw_varbinds=raw_varbinds)
                self.timeouts = self.timeouts_limit
            except SNMPError as e:
                if e.code == TIMED_OUT:
                    if self.timeouts_limit:
                        self.timeouts -= 1
                        if not self.timeouts:
                            raise self.FatalTimeoutError()
                    raise self.TimeOutError()
                else:
                    raise

        version = self._get_snmp_version(version)
        self.get_ioloop().run_sync(run)
        r, self.result = self.result, None
        return r
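
    # Illustrative usage (not part of the original source): a single OID
    # yields a scalar, a list of OIDs yields a name -> value dict, e.g.:
    #
    #     sysdescr = self.snmp.get("1.3.6.1.2.1.1.1.0")  # SNMPv2-MIB::sysDescr.0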

    def set(self, *args):
        """
        Perform SNMP SET request
        :param args: Either a single list of (oid, value) varbinds,
            or oid and value as two positional arguments
        :returns: either result scalar or dict of name -> value
        """
        @tornado.gen.coroutine
        def run():
            try:
                self.result = yield snmp_set(
                    address=self.script.credentials["address"],
                    varbinds=varbinds,
                    community=str(self.script.credentials["snmp_rw"]),
                    tos=self.script.tos,
                    ioloop=self.get_ioloop(),
                    udp_socket=self.get_socket())
            except SNMPError as e:
                if e.code == TIMED_OUT:
                    raise self.TimeOutError()
                else:
                    raise

        if len(args) == 1:
            varbinds = args
        elif len(args) == 2:
            varbinds = [(args[0], args[1])]
        else:
            raise ValueError("Invalid varbinds")
        self.get_ioloop().run_sync(run)
        r, self.result = self.result, None
        return r

    def count(self, oid, filter=None, version=None):
        """
        Iterate MIB subtree and count matching instances
        :param oid: OID
        :param filter: Callable accepting oid and value and returning boolean
        """
        @tornado.gen.coroutine
        def run():
            try:
                self.result = yield snmp_count(
                    address=self.script.credentials["address"],
                    oid=oid,
                    community=str(self.script.credentials["snmp_ro"]),
                    bulk=self.script.has_snmp_bulk(),
                    filter=filter,
                    tos=self.script.tos,
                    ioloop=self.get_ioloop(),
                    udp_socket=self.get_socket(),
                    version=version)
            except SNMPError as e:
                if e.code == TIMED_OUT:
                    raise self.TimeOutError()
                else:
                    raise

        version = self._get_snmp_version(version)
        self.get_ioloop().run_sync(run)
        r, self.result = self.result, None
        return r

    def getnext(self,
                oid,
                community_suffix=None,
                filter=None,
                cached=False,
                only_first=False,
                bulk=None,
                max_repetitions=None,
                version=None,
                max_retries=0,
                timeout=10,
                raw_varbinds=False):
        @tornado.gen.coroutine
        def run():
            try:
                self.result = yield snmp_getnext(
                    address=self.script.credentials["address"],
                    oid=oid,
                    community=str(self.script.credentials["snmp_ro"]),
                    bulk=self.script.has_snmp_bulk() if bulk is None else bulk,
                    max_repetitions=max_repetitions,
                    filter=filter,
                    only_first=only_first,
                    tos=self.script.tos,
                    ioloop=self.get_ioloop(),
                    udp_socket=self.get_socket(),
                    version=version,
                    max_retries=max_retries,
                    timeout=timeout,
                    raw_varbinds=raw_varbinds)
            except SNMPError as e:
                if e.code == TIMED_OUT:
                    raise self.TimeOutError()
                else:
                    raise

        version = self._get_snmp_version(version)
        self.get_ioloop().run_sync(run)
        r, self.result = self.result, None
        return r

    def get_table(self, oid, community_suffix=None, cached=False):
        """
        GETNEXT wrapper. Returns a hash of <index> -> <value>
        """
        r = {}
        for o, v in self.getnext(oid,
                                 community_suffix=community_suffix,
                                 cached=cached):
            r[int(o.split(".")[-1])] = v
        return r

    def join_tables(self, oid1, oid2, community_suffix=None, cached=False):
        """
        Generator returning a rows of two snmp tables joined by index
        """
        t1 = self.get_table(oid1,
                            community_suffix=community_suffix,
                            cached=cached)
        t2 = self.get_table(oid2,
                            community_suffix=community_suffix,
                            cached=cached)
        for k1, v1 in t1.items():
            try:
                yield (v1, t2[k1])
            except KeyError:
                pass

    def get_tables(self,
                   oids,
                   community_suffix=None,
                   bulk=False,
                   min_index=None,
                   max_index=None,
                   cached=False,
                   max_retries=0):
        """
        Query list of SNMP tables referenced by oids and yields
        tuples of (key, value1, ..., valueN)

        :param oids: List of OIDs
        :param community_suffix: Optional suffix to be added to community
        :param bulk: Use BULKGETNEXT if true
        :param min_index:
        :param max_index:
        :param cached:
        :param max_retries:
        :return:
        """
        def gen_table(oid):
            # Offset past the base OID and its trailing dot;
            # the remainder of each returned OID is the table index
            line = len(oid) + 1
            for o, v in self.getnext(oid,
                                     community_suffix=community_suffix,
                                     cached=cached,
                                     bulk=bulk,
                                     max_retries=max_retries):
                yield tuple([int(x) for x in o[line:].split(".")]), v

        # Retrieve tables
        tables = [dict(gen_table(oid)) for oid in oids]
        # Generate index
        index = set()
        for t in tables:
            index.update(t)
        # Yield result
        for i in sorted(index):
            yield [".".join([str(x) for x in i])] + [t.get(i) for t in tables]

    def join(self, oids, community_suffix=None, cached=False, join="left"):
        """
        Query list of tables, merge by oid index
        Tables are records of:
        * <oid>.<index> = value

        join may be one of:
        * left
        * inner
        * outer

        Yield records of (<index>, <value1>, ..., <valueN>)
        """
        tables = [
            self.get_table(o, community_suffix=community_suffix, cached=cached)
            for o in oids
        ]
        if join == "left":
            lt = tables[1:]
            for k in sorted(tables[0]):
                yield tuple([k, tables[0][k]] + [t.get(k) for t in lt])
        elif join == "inner":
            keys = set(tables[0])
            for lt in tables[1:]:
                keys &= set(lt)
            for k in sorted(keys):
                yield tuple([k] + [t.get(k) for t in tables])
        elif join == "outer":
            keys = set(tables[0])
            for lt in tables[1:]:
                keys |= set(lt)
            for k in sorted(keys):
                yield tuple([k] + [t.get(k) for t in tables])

    def get_chunked(self, oids, chunk_size=20, timeout_limits=3):
        """
        Fetch list of oids splitting to several operations when necessary

        :param oids: List of oids
        :param chunk_size: Maximal GET chunk size
        :param timeout_limits: SNMP timeout limits
        :return: dict of oid -> value for all retrieved values
        """
        results = {}
        self.set_timeout_limits(timeout_limits)
        while oids:
            chunk, oids = oids[:chunk_size], oids[chunk_size:]
            chunk = dict((x, x) for x in chunk)
            try:
                results.update(self.get(chunk))
            except self.TimeOutError as e:
                self.logger.error("Failed to get SNMP OIDs %s: %s", oids, e)
            except self.FatalTimeoutError:
                self.logger.error("Fatal timeout error on: %s", oids)
                break
            except self.SNMPError as e:
                self.logger.error("SNMP error code %s", e.code)
        return results
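
The class above gives scripts a synchronous facade over the asynchronous
snmp_get/snmp_getnext/snmp_count coroutines: each public method spins the
script's IOLoop via run_sync() and hands back self.result. A minimal usage
sketch from inside a script, assuming self.snmp is bound to this class; the
surrounding execute_snmp() method is hypothetical, and the OIDs are standard
SNMPv2-MIB/IF-MIB values:

def execute_snmp(self):
    # Scalar GET: sysName.0 returns a single value
    name = self.snmp.get("1.3.6.1.2.1.1.5.0")
    # GETNEXT walk over ifDescr: yields (oid, value) pairs
    for oid, descr in self.snmp.getnext("1.3.6.1.2.1.2.2.1.2"):
        self.logger.debug("ifIndex %s -> %s", oid.split(".")[-1], descr)
    # The same walk keyed by the last OID component
    descrs = self.snmp.get_table("1.3.6.1.2.1.2.2.1.2")
    return {"name": name, "interfaces": descrs}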
Code example #7
class RPCProxy(object):
    """
    API Proxy
    """
    RPCError = RPCError

    def __init__(self, service, service_name, sync=False, hints=None):
        self._logger = PrefixLoggerAdapter(logger, service_name)
        self._service = service
        self._service_name = service_name
        self._api = service_name.split("-")[0]
        self._tid = itertools.count()
        self._transactions = {}
        self._hints = hints
        self._sync = sync

    def __getattr__(self, item):
        @tornado.gen.coroutine
        def _call(method, *args, **kwargs):
            @tornado.gen.coroutine
            def make_call(url, body, limit=3):
                req_headers = {
                    "X-NOC-Calling-Service": self._service.name,
                    "Content-Type": "text/json"
                }
                sample = 1 if span_ctx and span_id else 0
                with Span(server=self._service_name,
                          service=method,
                          sample=sample,
                          context=span_ctx,
                          parent=span_id) as span:
                    if sample:
                        req_headers["X-NOC-Span-Ctx"] = span.span_context
                        req_headers["X-NOC-Span"] = span.span_id
                    code, headers, data = yield fetch(
                        url,
                        method="POST",
                        headers=req_headers,
                        body=body,
                        connect_timeout=CONNECT_TIMEOUT,
                        request_timeout=REQUEST_TIMEOUT)
                    # Process response
                    if code == 200:
                        raise tornado.gen.Return(data)
                    elif code == 307:
                        # Process redirect
                        if not limit:
                            raise RPCException("Redirects limit exceeded")
                        url = headers.get("location")
                        self._logger.debug("Redirecting to %s", url)
                        r = yield make_call(url, data, limit - 1)
                        raise tornado.gen.Return(r)
                    elif code in (598, 599):
                        span.error_code = code
                        self._logger.debug("Timed out")
                        raise tornado.gen.Return(None)
                    else:
                        span.error_code = code
                        raise RPCHTTPError("HTTP Error %s: %s" % (code, body))

            t0 = time.time()
            self._logger.debug("[%sCALL>] %s.%s(%s, %s)",
                               "SYNC " if self._sync else "",
                               self._service_name, method, args, kwargs)
            metrics["rpc_call", ("called_service", self._service_name),
                    ("method", method)] += 1
            tid = next(self._tid)
            msg = {"method": method, "params": list(args)}
            is_notify = "_notify" in kwargs
            if not is_notify:
                msg["id"] = tid
            body = ujson.dumps(msg)
            # Get services
            response = None
            for t in self._service.iter_rpc_retry_timeout():
                # Resolve service against service catalog
                if self._hints:
                    svc = random.choice(self._hints)
                else:
                    svc = yield self._service.dcs.resolve(self._service_name)
                response = yield make_call(
                    "http://%s/api/%s/" % (svc, self._api), body)
                if response:
                    break
                else:
                    yield tornado.gen.sleep(t)
            t = time.time() - t0
            self._logger.debug("[CALL<] %s.%s (%.2fms)", self._service_name,
                               method, t * 1000)
            if response:
                if not is_notify:
                    try:
                        result = ujson.loads(response)
                    except ValueError as e:
                        raise RPCHTTPError("Cannot decode json: %s" % e)
                    if result.get("error"):
                        self._logger.error("RPC call failed: %s",
                                           result["error"])
                        raise RPCRemoteError(
                            "RPC call failed: %s" % result["error"],
                            remote_code=result.get("code", None))
                    else:
                        raise tornado.gen.Return(result["result"])
                else:
                    # Notifications return None
                    raise tornado.gen.Return()
            else:
                raise RPCNoService("No active service %s found" %
                                   self._service_name)

        @tornado.gen.coroutine
        def async_wrapper(*args, **kwargs):
            result = yield _call(item, *args, **kwargs)
            raise tornado.gen.Return(result)

        def sync_wrapper(*args, **kwargs):
            @tornado.gen.coroutine
            def _sync_call():
                try:
                    r = yield _call(item, *args, **kwargs)
                    result.append(r)
                except tornado.gen.Return as e:
                    result.append(e.value)
                except Exception:
                    error.append(sys.exc_info())
                finally:
                    ev.set()

            ev = threading.Event()
            result = []
            error = []
            self._service.ioloop.add_callback(_sync_call)
            ev.wait()
            if error:
                six.reraise(*error[0])
            else:
                return result[0]

        if item.startswith("_"):
            return self.__dict__[item]
        span_ctx, span_id = get_current_span()
        if self._sync:
            return sync_wrapper
        else:
            return async_wrapper
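
RPCProxy turns attribute access into JSON-RPC calls: __getattr__ builds a
coroutine that resolves the target service through the DCS catalog (or the
supplied hints), POSTs the encoded request, follows 307 redirects and sleeps
between retries on timeout; with sync=True the call is marshalled onto the
service IOLoop and the caller blocks on a threading.Event. A hedged sketch of
both flavours; the service object, the "sae" service name and its ping method
are assumptions:

# Asynchronous flavour, from inside a coroutine:
proxy = RPCProxy(service, "sae")
# result = yield proxy.ping()

# Synchronous flavour, safe to call from a worker thread:
sync_proxy = RPCProxy(service, "sae", sync=True)
result = sync_proxy.ping()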
Code example #8
File: base.py Project: nbashev/noc
class BaseLoader(object):
    """
    Import directory structure:
    var/
        import/
            <system name>/
                <loader name>/
                    import.jsonl[.ext]  -- state to load, must have .ext extension
                                         according to selected compressor
                    mappings.csv -- ID mappings
                    archive/
                        import-YYYY-MM-DD-HH-MM-SS.jsonl.ext -- imported state

    Import file format: CSV, unix end of lines, UTF-8, comma-separated
    First column - record id in the terms of connected system,
    other columns must be defined in *fields* variable.

    File must be sorted by the first field, either as strings or as
    numbers; the sort order must not change between imports.

    mappings.csv -- CSV, unix end of lines, UTF-8, comma-separated
    mappings of IDs between NOC and the remote system. Populated by the
    loader automatically.

    :param fields: List of either field names or tuple of
        (field name, related loader name)
    """

    # Loader name
    name: str
    # Loader model (Database)
    model = None
    # Data model
    data_model: BaseModel

    # List of tags to add to the created records
    tags = []

    rx_archive = re.compile(r"^import-\d{4}(?:-\d{2}){5}\.jsonl%s$" %
                            compressor.ext.replace(".", r"\."))

    # Discard records which cannot be dereferenced
    discard_deferred = False
    # Ignore auto-generated unique fields
    ignore_unique = {"bi_id"}

    REPORT_INTERVAL = 1000

    class Deferred(Exception):
        pass

    def __init__(self, chain):
        self.chain = chain
        self.system = chain.system
        self.logger = PrefixLoggerAdapter(
            logger, "%s][%s" % (self.system.name, self.name))
        self.disable_mappings = False
        self.import_dir = os.path.join(config.path.etl_import,
                                       self.system.name, self.name)
        self.archive_dir = os.path.join(self.import_dir, "archive")
        self.mappings_path = os.path.join(self.import_dir, "mappings.csv")
        self.mappings = {}
        self.new_state_path = None
        self.c_add = 0
        self.c_change = 0
        self.c_delete = 0
        # Mapped fields
        self.mapped_fields = self.data_model.get_mapped_fields()
        # Build clean map
        self.clean_map = {}  # field name -> clean function
        self.pending_deletes: List[Tuple[str,
                                         BaseModel]] = []  # (id, BaseModel)
        self.referred_errors: List[Tuple[str,
                                         BaseModel]] = []  # (id, BaseModel)
        if self.is_document:
            import mongoengine.errors

            unique_fields = [
                f.name for f in self.model._fields.values()
                if f.unique and f.name not in self.ignore_unique
            ]
            self.integrity_exception = mongoengine.errors.NotUniqueError
        else:
            # Third-party modules
            import django.db.utils

            unique_fields = [
                f.name for f in self.model._meta.fields
                if f.unique and f.name != self.model._meta.pk.name
                and f.name not in self.ignore_unique
            ]
            self.integrity_exception = django.db.utils.IntegrityError
        if unique_fields:
            self.unique_field = unique_fields[0]
        else:
            self.unique_field = None
        self.has_remote_system: bool = hasattr(self.model, "remote_system")

    @property
    def is_document(self):
        """
        Returns True if model is Document, False - if Model
        """
        return hasattr(self.model, "_fields")

    def load_mappings(self):
        """
        Load mappings file
        """
        if self.model:
            if self.is_document:
                self.update_document_clean_map()
            else:
                self.update_model_clean_map()
        if not os.path.exists(self.mappings_path):
            return
        self.logger.info("Loading mappings from %s", self.mappings_path)
        with open(self.mappings_path) as f:
            reader = csv.reader(f)
            for k, v in reader:
                self.mappings[self.clean_str(k)] = v
        self.logger.info("%d mappings restored", len(self.mappings))

    def get_new_state(self) -> Optional[TextIOWrapper]:
        """
        Returns file object of new state, or None when not present
        """
        # Try import.jsonl with the configured compressor's extension
        path = compressor.get_path(
            os.path.join(self.import_dir, "import.jsonl"))
        if not os.path.isfile(path):
            return None
        logger.info("Loading from %s", path)
        self.new_state_path = path
        return compressor(path, "r").open()

    def get_current_state(self) -> TextIOWrapper:
        """
        Returns file object of current state or None
        """
        self.load_mappings()
        if not os.path.isdir(self.archive_dir):
            self.logger.info("Creating archive directory: %s",
                             self.archive_dir)
            try:
                os.mkdir(self.archive_dir)
            except OSError as e:
                self.logger.error("Failed to create directory: %s (%s)",
                                  self.archive_dir, e)
                # @todo: Die
        if os.path.isdir(self.archive_dir):
            fn = list(
                sorted(f for f in os.listdir(self.archive_dir)
                       if self.rx_archive.match(f)))
        else:
            fn = []
        if not fn:
            return StringIO("")
        path = os.path.join(self.archive_dir, fn[-1])
        logger.info("Current state from %s", path)
        return compressor(path, "r").open()

    def iter_jsonl(
            self,
            f: TextIOWrapper,
            data_model: Optional[BaseModel] = None) -> Iterable[BaseModel]:
        """
        Iterate over a JSONL stream and yield model instances
        :param f:
        :param data_model:
        :return:
        """
        dm = data_model or self.data_model
        for line in f:
            yield dm.parse_raw(line.replace("\\r", ""))

    def diff(
        self,
        old: Iterable[BaseModel],
        new: Iterable[BaseModel],
        include_fields: Set = None
    ) -> Iterable[Tuple[Optional[BaseModel], Optional[BaseModel]]]:
        """
        Compare old and new CSV files and yield pair of matches
        * old, new -- when changed
        * old, None -- when removed
        * None, new -- when added
        """

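        # Classic merge join over two id-sorted iterators; correctness
        # relies on the import file contract above (sorted by the first
        # field, stable sort order between imports)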
        o = next(old, None)
        n = next(new, None)
        while o or n:
            if not o:
                # New
                yield None, n
                n = next(new, None)
            elif not n:
                # Removed
                yield o, None
                o = next(old, None)
            else:
                if n.id == o.id:
                    # Changed
                    if n.dict(include=include_fields) != o.dict(
                            include=include_fields):
                        yield o, n
                    n = next(new, None)
                    o = next(old, None)
                elif n.id < o.id:
                    # Added
                    yield None, n
                    n = next(new, None)
                else:
                    # Removed
                    yield o, None
                    o = next(old, None)

    def load(self):
        """
        Import new data
        """
        self.logger.info("Importing")
        ns = self.get_new_state()
        if not ns:
            self.logger.info("No new state, skipping")
            self.load_mappings()
            return
        current_state = self.iter_jsonl(self.get_current_state())
        new_state = self.iter_jsonl(ns)
        deferred_add = []
        deferred_change = []
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                try:
                    self.on_add(n)
                except self.Deferred:
                    if not self.discard_deferred:
                        deferred_add += [n]
            elif o and n is None:
                self.on_delete(o)
            else:
                try:
                    self.on_change(o, n)
                except self.Deferred:
                    if not self.discard_deferred:
                        deferred_change += [(o, n)]
            rn = self.c_add + self.c_change + self.c_delete
            if rn > 0 and rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)
        # Add deferred records
        while len(deferred_add):
            nd = []
            for row in deferred_add:
                try:
                    self.on_add(row)
                except self.Deferred:
                    nd += [row]
            if len(nd) == len(deferred_add):
                raise Exception("Unable to resolve deferred references")
            deferred_add = nd
            rn = self.c_add + self.c_change + self.c_delete
            if rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)
        # Change deferred records
        while len(deferred_change):
            nd = []
            for o, n in deferred_change:
                try:
                    self.on_change(o, n)
                except self.Deferred:
                    nd += [(o, n)]
            if len(nd) == len(deferred_change):
                raise Exception("Unable to resolve deferred references")
            deferred_change = nd
            rn = self.c_add + self.c_change + self.c_delete
            if rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)

    def find_object(self, v: Dict[str, Any]) -> Optional[Any]:
        """
        Find object by remote system/remote id
        :param v:
        :return:
        """
        self.logger.debug("Find object: %s", v)
        if not self.has_remote_system:
            return None
        if not v.get("remote_system") or not v.get("remote_id"):
            self.logger.warning("RS or RID not found")
            return None
        find_query = {
            "remote_system": v.get("remote_system"),
            "remote_id": v.get("remote_id")
        }
        try:
            return self.model.objects.get(**find_query)
        except self.model.MultipleObjectsReturned:
            if self.unique_field:
                find_query[self.unique_field] = v.get(self.unique_field)
                r = self.model.objects.filter(**find_query)
                if not r:
                    r = self.model.objects.filter(**find_query)
                return list(r)[-1]
            raise self.model.MultipleObjectsReturned
        except self.model.DoesNotExist:
            self.logger.debug("Object not found")
            return None

    def create_object(self, v):
        """
        Create object with attributes. Override to save complex
        data structures
        """
        self.logger.debug("Create object")
        for k, nv in v.items():
            if k == "tags":
                # Merge tags
                nv = sorted("%s:%s" % (self.system.name, x) for x in nv)
                v[k] = nv
        o = self.model(**v)
        try:
            o.save()
        except self.integrity_exception as e:
            self.logger.warning("Integrity error: %s", e)
            assert self.unique_field
            if not self.is_document:
                from django.db import connection

                connection._rollback()
            # Fallback to change object
            o = self.model.objects.get(
                **{self.unique_field: v[self.unique_field]})
            for k, nv in v.items():
                setattr(o, k, nv)
            o.save()
        return o

    def change_object(self, object_id: str, v: Dict[str, Any]):
        """
        Change object with attributes
        """
        self.logger.debug("Changed object: %s", v)
        # See: https://code.getnoc.com/noc/noc/merge_requests/49
        try:
            o = self.model.objects.get(pk=object_id)
        except self.model.DoesNotExist:
            self.logger.error("Cannot change %s:%s: Does not exists",
                              self.name, object_id)
            return None
        for k, nv in v.items():
            if k == "tags":
                # Merge tags
                ov = o.tags or []
                nv = sorted([
                    x for x in ov if not (x.startswith(self.system.name + ":")
                                          or x == "remote:deleted")
                ] + ["%s:%s" % (self.system.name, x) for x in nv])
            setattr(o, k, nv)
        o.save()
        return o

    def on_add(self, item: BaseModel) -> None:
        """
        Create new record
        """
        self.logger.debug("Add: %s", item.json())
        v = self.clean(item)
        # @todo: Check whether the record already exists
        if "id" in v:
            del v["id"]
        o = self.find_object(v)
        if o:
            self.c_change += 1
            # Lost&found object with same remote_id
            self.logger.debug("Lost and Found object")
            vv = {
                "remote_system": v["remote_system"],
                "remote_id": v["remote_id"]
            }
            for fn, nv in v.items():
                if fn in vv:
                    continue
                if getattr(o, fn) != nv:
                    vv[fn] = nv
            self.change_object(o.id, vv)
        else:
            self.c_add += 1
            o = self.create_object(v)
        self.set_mappings(item.id, o.id)

    def on_change(self, o: BaseModel, n: BaseModel):
        """
        Create change record
        """
        self.logger.debug("Change: %s", n.json())
        self.c_change += 1
        nv = self.clean(n)
        changes = {
            "remote_system": nv["remote_system"],
            "remote_id": nv["remote_id"]
        }
        ov = self.clean(o)
        for fn in self.data_model.__fields__:
            if fn == "id":
                continue
            if ov[fn] != nv[fn]:
                self.logger.debug("   %s: %s -> %s", fn, ov[fn], nv[fn])
                changes[fn] = nv[fn]
        if n.id in self.mappings:
            self.change_object(self.mappings[n.id], changes)
        else:
            self.logger.error("Cannot map id '%s'. Skipping.", n.id)

    def on_delete(self, item: BaseModel):
        """
        Delete record
        """
        self.pending_deletes += [(item.id, item)]

    def purge(self):
        """
        Perform pending deletes
        """
        for r_id, msg in reversed(self.pending_deletes):
            self.logger.debug("Delete: %s", msg)
            self.c_delete += 1
            try:
                obj = self.model.objects.get(pk=self.mappings[r_id])
                obj.delete()
            except ValueError as e:  # Referred Error
                self.logger.error("%s", str(e))
                self.referred_errors += [(r_id, msg)]
            except KeyError as e:
                # Undefined mappings
                self.logger.error("%s", str(e))
            except self.model.DoesNotExist:
                pass  # Already deleted
        self.pending_deletes = []

    def save_state(self):
        """
        Save current state
        """
        if not self.new_state_path:
            return
        self.logger.info("Summary: %d new, %d changed, %d removed", self.c_add,
                         self.c_change, self.c_delete)
        self.logger.info("Error delete by referred: %s",
                         "\n".join(b.json() for _, b in self.referred_errors))
        t = time.localtime()
        archive_path = os.path.join(
            self.archive_dir,
            compressor.get_path("import-%04d-%02d-%02d-%02d-%02d-%02d.jsonl" %
                                tuple(t[:6])),
        )
        self.logger.info("Moving %s to %s", self.new_state_path, archive_path)
        if self.new_state_path.endswith(compressor.ext):
            # Simply move the file
            shutil.move(self.new_state_path, archive_path)
        else:
            # Compress the file
            self.logger.info("Compressing")
            with open(self.new_state_path,
                      "r") as s, compressor(archive_path, "w") as d:
                d.write(s.read())
            os.unlink(self.new_state_path)
        self.logger.info("Saving mappings to %s", self.mappings_path)
        mdata = "\n".join("%s,%s" % (k, self.mappings[k])
                          for k in sorted(self.mappings))
        safe_rewrite(self.mappings_path, mdata)

    def clean(self, item: BaseModel) -> Dict[str, Any]:
        """
        Cleanup row and return a dict of field name -> value
        """
        r = {
            k: self.clean_map.get(k, self.clean_any)(v)
            for k, v in item.dict().items()
        }
        # Fill integration fields
        r["remote_system"] = self.system.remote_system
        r["remote_id"] = self.clean_str(item.id)
        return r

    def clean_any(self, value: Any) -> Any:
        return value

    def clean_str(self, value) -> Optional[str]:
        if not value:
            return None
        if isinstance(value, str):
            return smart_text(value)
        return str(value)

    def clean_map_str(self, mappings, value):
        value = self.clean_str(value)
        if self.disable_mappings and not mappings:
            return value
        elif value:
            try:
                value = mappings[value]
            except KeyError:
                self.logger.warning("Deferred. Unknown map value: %s", value)
                raise self.Deferred
        return value

    def clean_bool(self, value: str) -> Optional[bool]:
        if value == "":
            return None
        try:
            return int(value) != 0
        except ValueError:
            pass
        value = value.lower()
        return value in ("t", "true", "y", "yes")

    def clean_reference(self, mappings, r_model, value):
        if not value:
            return None
        elif self.disable_mappings and not mappings:
            return value
        else:
            # @todo: Get proper mappings
            try:
                value = mappings[value]
            except KeyError:
                self.logger.info("Deferred. Unknown value %s:%s", r_model,
                                 value)
                raise self.Deferred()
            return self.chain.cache[r_model, value]

    def clean_int_reference(self, mappings, r_model, value):
        if not value:
            return None
        elif self.disable_mappings and not mappings:
            return value
        else:
            # @todo: Get proper mappings
            try:
                value = int(mappings[value])
            except KeyError:
                self.logger.info("Deferred. Unknown value %s:%s", r_model,
                                 value)
                raise self.Deferred()
            return self.chain.cache[r_model, value]

    def set_mappings(self, rv, lv):
        self.logger.debug("Set mapping remote: %s, local: %s", rv, lv)
        self.mappings[str(rv)] = str(lv)

    def update_document_clean_map(self):
        from mongoengine.fields import BooleanField, ReferenceField
        from noc.core.mongo.fields import PlainReferenceField, ForeignKeyField

        self.logger.debug("Update Document clean map")
        for fn, ft in self.model._fields.items():
            if fn not in self.data_model.__fields__:
                continue
            if isinstance(ft, BooleanField):
                self.clean_map[fn] = self.clean_bool
            elif isinstance(ft, (PlainReferenceField, ReferenceField)):
                if fn in self.mapped_fields:
                    self.clean_map[fn] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[fn]),
                        ft.document_type,
                    )
            elif isinstance(ft, ForeignKeyField):
                if fn in self.mapped_fields:
                    self.clean_map[fn] = functools.partial(
                        self.clean_int_reference,
                        self.chain.get_mappings(self.mapped_fields[fn]),
                        ft.document_type,
                    )
            elif fn in self.mapped_fields:
                self.clean_map[fn] = functools.partial(
                    self.clean_map_str,
                    self.chain.get_mappings(self.mapped_fields[fn]))

    def update_model_clean_map(self):
        from django.db.models import BooleanField, ForeignKey
        from noc.core.model.fields import DocumentReferenceField

        self.logger.debug("Update Model clean map")
        for f in self.model._meta.fields:
            if f.name not in self.data_model.__fields__:
                continue
            if isinstance(f, BooleanField):
                self.clean_map[f.name] = self.clean_bool
            elif isinstance(f, DocumentReferenceField):
                if f.name in self.mapped_fields:
                    self.clean_map[f.name] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[f.name]),
                        f.document,
                    )
            elif isinstance(f, ForeignKey):
                if f.name in self.mapped_fields:
                    self.clean_map[f.name] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[f.name]),
                        f.remote_field.model,
                    )
            elif f.name in self.mapped_fields:
                self.clean_map[f.name] = functools.partial(
                    self.clean_map_str,
                    self.chain.get_mappings(self.mapped_fields[f.name]))

    def check(self, chain):
        self.logger.info("Checking")
        # Get constraints
        if self.is_document:
            # Document
            required_fields = [
                f.name for f in self.model._fields.values()
                if f.required or f.unique
            ]
            unique_fields = [
                f.name for f in self.model._fields.values() if f.unique
            ]
        else:
            # Model
            required_fields = [
                f.name for f in self.model._meta.fields if not f.blank
            ]
            unique_fields = [
                f.name for f in self.model._meta.fields
                if f.unique and f.name != self.model._meta.pk.name
            ]
        if not required_fields and not unique_fields:
            self.logger.info("Nothing to check, skipping")
            return 0
        self.logger.debug("[%s] Required fields: %s", self.model,
                          required_fields)
        self.logger.debug("[%s] Unique fields: %s", self.model, unique_fields)
        self.logger.debug("[%s] Mapped fields: %s", self.model,
                          self.mapped_fields)
        # Prepare data
        ns = self.get_new_state()
        if not ns:
            self.logger.info("No new state, skipping")
            return 0
        new_state = self.iter_jsonl(ns)
        uv = set()
        m_data = {}  # field_number -> set of mapped ids
        # Load mapped ids
        for f in self.mapped_fields:
            line = chain.get_loader(self.mapped_fields[f])
            ls = line.get_new_state()
            if not ls:
                ls = line.get_current_state()
            ms = self.iter_jsonl(ls, data_model=line.data_model)
            m_data[self.data_model.__fields__[f].name] = set(row.id
                                                             for row in ms)
        # Process data
        n_errors = 0
        for row in new_state:
            row = row.dict()
            lr = len(row)
            # Check required fields
            for f in required_fields:
                if f not in self.data_model.__fields__:
                    continue
                if f not in row:
                    self.logger.error(
                        "ERROR: Required field #(%s) is missed in row: %s",
                        f,
                        # self.fields[i],
                        row,
                    )
                    n_errors += 1
                    continue
            # Check unique fields
            for f in unique_fields:
                if f in self.ignore_unique:
                    continue
                v = row[f]
                if v in uv:
                    self.logger.error(
                        "ERROR: Field #(%s) value is not unique: %s",
                        f,
                        # self.fields[i],
                        row,
                    )
                    n_errors += 1
                else:
                    uv.add(v)
            # Check mapped fields
            for i, f in enumerate(self.mapped_fields):
                if i >= lr:
                    continue
                v = row[f]
                if v and v not in m_data[f]:
                    self.logger.error(
                        "ERROR: Field #%d(%s) == '%s' refers to non-existent record: %s",
                        i,
                        f,
                        row[f],
                        row,
                    )
                    n_errors += 1
        if n_errors:
            self.logger.info("%d errors found", n_errors)
        else:
            self.logger.info("No errors found")
        return n_errors

    def check_diff(self):
        def dump(cmd, row):
            print("%s %s" % (cmd, row.json()))

        print("--- %s.%s" % (self.chain.system.name, self.name))
        ns = self.get_new_state()
        if not ns:
            return
        current_state = self.iter_jsonl(self.get_current_state())
        new_state = self.iter_jsonl(ns)
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                dump("+", n)
            elif o and n is None:
                dump("-", o)
            else:
                dump("/", o)
                dump("\\", n)

    def check_diff_summary(self):
        i, u, d = 0, 0, 0
        ns = self.get_new_state()
        if not ns:
            return i, u, d
        current_state = self.iter_jsonl(self.get_current_state())
        new_state = self.iter_jsonl(ns)
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                i += 1
            elif o and n is None:
                d += 1
            else:
                u += 1
        return i, u, d
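
A concrete ETL loader only has to bind a database model and a pydantic data
model to this machinery; BaseLoader then diffs the archived state against the
new import file and applies on_add/on_change/on_delete, deferring records
whose references cannot be resolved yet. A minimal sketch under assumed names
(AdmDiv and AdministrativeDivision are illustrative):

class AdmDivLoader(BaseLoader):
    # Hypothetical wiring; both names are assumptions
    name = "admdiv"
    model = AdmDiv  # NOC database model
    data_model = AdministrativeDivision  # schema for import.jsonl rows

# Typical driving sequence, assuming a loader chain:
# loader = chain.get_loader("admdiv")
# loader.load()
# loader.purge()
# loader.save_state()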
Code example #9
class HTTP(object):
    HTTPError = HTTPError

    def __init__(self, script):
        self.script = script
        if script:  # For testing purposes
            self.logger = PrefixLoggerAdapter(script.logger, "http")
        self.headers = {}
        self.cookies = None
        self.session_started = False
        self.request_id = 1
        self.session_id = None
        self.request_middleware = None
        if self.script:  # For testing purposes
            self.setup_middleware()

    def get_url(self, path):
        address = self.script.credentials["address"]
        port = self.script.credentials.get("http_port")
        if port:
            address += ":%s" % port
        proto = self.script.credentials.get("http_protocol", "http")
        return "%s://%s%s" % (proto, address, path)

    def get(self,
            path,
            headers=None,
            cached=False,
            json=False,
            eof_mark=None,
            use_basic=False):
        """
        Perform HTTP GET request
        :param path: URI
        :param headers: Dict of additional headers
        :param cached: Cache result
        :param json: Decode json if set to True
        :param eof_mark: Wait for eof_mark in the stream to end the session
            (some devices return a zero-length response)
        :param use_basic: Use basic authentication
        """
        self.ensure_session()
        self.request_id += 1
        self.logger.debug("GET %s", path)
        if cached:
            cache_key = "get_%s" % path
            r = self.script.root.http_cache.get(cache_key)
            if r is not None:
                self.logger.debug("Use cached result")
                return r
        user, password = None, None
        if use_basic:
            user = self.script.credentials.get("user")
            password = self.script.credentials.get("password")
        # Apply GET middleware
        url = self.get_url(path)
        hdr = self._get_effective_headers(headers)
        if self.request_middleware:
            for mw in self.request_middleware:
                url, _, hdr = mw.process_get(url, "", hdr)
        code, headers, result = fetch_sync(
            url,
            headers=hdr,
            request_timeout=60,
            follow_redirects=True,
            allow_proxy=False,
            validate_cert=False,
            eof_mark=eof_mark,
            user=user,
            password=password,
        )
        if not 200 <= code <= 299:
            raise HTTPError(msg="HTTP Error (%s)" % result[:256], code=code)
        self._process_cookies(headers)
        if json:
            try:
                result = ujson.loads(result)
            except ValueError as e:
                raise HTTPError(msg="Failed to decode JSON: %s" % e)
        self.logger.debug("Result: %r", result)
        if cached:
            self.script.root.http_cache[cache_key] = result
        return result

    def post(self,
             path,
             data,
             headers=None,
             cached=False,
             json=False,
             eof_mark=None,
             use_basic=False):
        """
        Perform HTTP GET request
        :param path: URI
        :param headers: Dict of additional headers
        :param cached: Cache result
        :param json: Decode json if set to True
        :param eof_mark: Waiting eof_mark in stream for end session (perhaps device return length 0)
        :param use_basic: Use basic authentication
        """
        self.ensure_session()
        self.request_id += 1
        self.logger.debug("POST %s %s", path, data)
        if cached:
            cache_key = "post_%s" % path
            r = self.script.root.http_cache.get(cache_key)
            if r is not None:
                self.logger.debug("Use cached result")
                return r
        user, password = None, None
        if use_basic:
            user = self.script.credentials.get("user")
            password = self.script.credentials.get("password")
        # Apply POST middleware
        url = self.get_url(path)
        hdr = self._get_effective_headers(headers)
        if self.request_middleware:
            for mw in self.request_middleware:
                url, data, hdr = mw.process_post(url, data, hdr)
        code, headers, result = fetch_sync(
            url,
            method="POST",
            body=data,
            headers=hdr,
            request_timeout=60,
            follow_redirects=True,
            allow_proxy=False,
            validate_cert=False,
            eof_mark=eof_mark,
            user=user,
            password=password,
        )
        if not 200 <= code <= 299:
            raise HTTPError(msg="HTTP Error (%s)" % result[:256], code=code)
        self._process_cookies(headers)
        if json:
            try:
                result = ujson.loads(result)
            except ValueError as e:
                raise HTTPError(msg="Failed to decode JSON: %s" % e)
        self.logger.debug("Result: %r", result)
        if cached:
            self.script.root.http_cache[cache_key] = result
        return result

    def close(self):
        if self.session_started:
            self.shutdown_session()

    def _process_cookies(self, headers):
        """
        Process and store cookies from response headers
        :param headers:
        :return:
        """
        cdata = headers.get("Set-Cookie")
        if not cdata:
            return
        if not self.cookies:
            self.cookies = SimpleCookie()
        self.cookies.load(cdata)

    def get_cookie(self, name):
        """
        Get cookie by name
        :param name:
        :return: Morsel object or None
        """
        if not self.cookies:
            return None
        return self.cookies.get(name)

    def _get_effective_headers(self, headers):
        """
        Append session headers when necessary. Apply effective cookies
        :param headers:
        :return:
        """
        if self.headers:
            if headers:
                headers = headers.copy()
            else:
                headers = {}
            headers.update(self.headers)
        elif not headers and self.cookies:
            headers = {}
        if self.cookies:
            headers["Cookie"] = self.cookies.output(header="").lstrip()
        return headers

    def set_header(self, name, value):
        """
        Set HTTP header to be set with all following requests
        :param name:
        :param value:
        :return:
        """
        self.logger.debug("Set header: %s = %s", name, value)
        self.headers[name] = str(value)

    def set_session_id(self, session_id):
        """
        Set session_id to be reused by middleware
        :param session_id:
        :return: None
        """
        self.session_id = session_id

    def ensure_session(self):
        if not self.session_started:
            self.session_started = True
            self.setup_session()

    def setup_session(self):
        if self.script.profile.setup_http_session:
            self.logger.debug("Setup http session")
            self.script.profile.setup_http_session(self.script)

    def shutdown_session(self):
        if self.script.profile.shutdown_http_session:
            self.logger.debug("Shutdown http session")
            self.script.profile.shutdown_http_session(self.script)

    def setup_middleware(self):
        mw_list = self.script.profile.get_http_request_middleware(self.script)
        if not mw_list:
            return
        self.request_middleware = []
        for mw_cfg in mw_list:
            if isinstance(mw_cfg, tuple):
                name, cfg = mw_cfg
            else:
                name, cfg = mw_cfg, {}
            if "." in name:
                # Handler
                mw_cls = get_handler(name)
                assert mw_cls
                assert issubclass(mw_cls, BaseMiddleware)
            else:
                # Middleware name
                mw_cls = loader.get_class(name)
            self.request_middleware += [mw_cls(self, **cfg)]
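
Inside a script, the HTTP helper transparently handles session setup, cookie
storage and profile request middleware. A hedged usage sketch; the REST paths
and the surrounding execute() method are illustrative:

def execute(self):
    # GET with JSON decoding, cached for the session
    status = self.http.get("/api/v1/status", json=True, cached=True)
    # POST with basic authentication
    r = self.http.post("/api/v1/config", '{"save": true}',
                       json=True, use_basic=True)
    return {"status": status, "saved": r}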
Code example #10
File: engine.py Project: gabrielat/noc
class Engine(object):
    ILOCK = threading.Lock()
    AC_POLICY_VIOLATION = None

    def __init__(self, object):
        self.object = object
        self.logger = PrefixLoggerAdapter(logger, self.object.name)
        self.env = None
        self.templates = {}  # fact class -> template
        self.fcls = {}  # template -> Fact class
        self.facts = {}  # Index -> Fact
        self.rn = 0  # Rule number
        self.config = None  # Cached config
        self.interface_ranges = None
        with self.ILOCK:
            self.AC_POLICY_VIOLATION = AlarmClass.objects.filter(
                name="Config | Policy Violation").first()
            if not self.AC_POLICY_VIOLATION:
                logger.error(
                    "Alarm class 'Config | Policy Violation' is not found. Alarms cannot be raised"
                )

    def get_template(self, fact):
        if fact.cls not in self.templates:
            self.logger.debug("Creating template %s", fact.cls)
            self.templates[fact.cls] = self.env.BuildTemplate(
                fact.cls, fact.get_template())
            self.fcls[fact.cls] = fact.__class__
            self.logger.debug("Define template %s",
                              self.templates[fact.cls].PPForm())
        return self.templates[fact.cls]

    def get_rule_number(self):
        return self.rn

    def assert_fact(self, fact):
        f = self.get_template(fact).BuildFact()
        f.AssignSlotDefaults()
        for k, v in fact.iter_factitems():
            if v is None or v == [] or v == tuple():
                continue
            if isinstance(v, six.string_types):
                v = v.replace("\n", "\\n")
            f.Slots[k] = v
        try:
            f.Assert()
        except clips.ClipsError as e:
            self.logger.error("Could not assert: %s", f.PPForm())
            self.logger.error("CLIPS Error: %s\n%s", e,
                              clips.ErrorStream.Read())
            return
        self.facts[f.Index] = fact
        self.logger.debug("Assert %s", f.PPForm())

    def learn(self, gen):
        """
        Learn sequence of facts
        """
        n = 0
        for f in gen:
            if hasattr(f, "managed_object") and f.managed_object is not None:
                f.bind()
                # @todo: Custom bindings from solutions
            self.assert_fact(f)
            n += 1

    def iter_errors(self):
        """
        Generator yielding known errors
        """
        try:
            e = self.templates["error"].InitialFact()
        except TypeError:
            return  # PEP 479: raising StopIteration here would become RuntimeError
        while e:
            if "obj" in e.Slots.keys():
                obj = e.Slots["obj"]
                if hasattr(obj, "Index"):
                    # obj is a fact
                    if obj.Index in self.facts:
                        obj = self.facts[obj.Index]
            else:
                obj = None
            error = Error(e.Slots["type"], obj=obj, msg=e.Slots["msg"])
            if e.Index not in self.facts:
                self.facts[e.Index] = error
            yield error
            e = e.Next()

    def iter_roles(self):
        """
        Generator yielding role fact
        """
        try:
            e = self.templates["role"].InitialFact()
        except TypeError:
            return  # PEP 479: raising StopIteration here would become RuntimeError
        while e:
            role = Role(e.Slots["name"])
            if e.Index not in self.facts:
                self.facts[e.Index] = role
            yield role
            e = e.Next()

    def run(self):
        """
        Run engine round
        :returns: Number of matched rules
        """
        return self.env.Run()

    def add_rule(self, expr):
        self.env.Build(expr)
        self.rn += 1

    def check(self):
        with CLIPSEnv() as env:
            self.setup_env(env)
            self._check()

    def _check(self):
        """
        Perform object configuration check
        """
        self.logger.info("Checking %s", self.object)
        parser = self.object.get_parser()
        self.config = self.object.config.read()
        if not self.config:
            self.logger.error("No config for %s. Giving up", self.object)
            return
        # Parse facts
        self.logger.debug("Parsing facts")
        facts = list(parser.parse(self.config))
        self.logger.debug("%d facts are extracted", len(facts))
        self.interface_ranges = parser.interface_ranges
        self.logger.debug("%d interface sections detected",
                          len(self.interface_ranges))
        # Define default templates
        self.get_template(Error(None))
        self.get_template(Role(None))
        # Learn facts
        self.logger.debug("Learning facts")
        self.learn(facts)
        self.logger.debug("Learning complete")
        # Install rules
        rules = []
        for r in self.get_rules():
            if r.is_applicable():
                self.logger.debug("Using validation rule: %s", r.rule.name)
                try:
                    cfg = r.get_config()
                    r.prepare(**cfg)
                except clips.ClipsError as e:
                    self.logger.error("CLIPS Error: %s\n%s", e,
                                      clips.ErrorStream.Read())
                    continue
                except Exception:
                    error_report()
                    continue
                rules += [(r, cfg)]
        # Run python validators
        for r, cfg in rules:
            r.check(**cfg)
        # Run CLIPS engine
        while True:
            self.logger.debug("Running engine")
            n = self.run()
            self.logger.debug("%d rules matched", n)
            break  # @todo: Check for commands
        # Extract errors
        for e in self.iter_errors():
            self.logger.info("Error found: %s", e)
        # Store object's facts
        self.sync_facts()
        # Manage related alarms
        if self.AC_POLICY_VIOLATION:
            self.sync_alarms()

    def _get_rule_settings(self, ps, scope):
        """
        Process PolicySettings object and returns a list of
        (validator class, config)
        """
        r = []
        for pi in ps.policies:
            policy = pi.policy
            if not pi.is_active or not policy.is_active:
                continue
            for ri in policy.rules:
                if not ri.is_active:
                    continue
                rule = ri.rule
                if rule.is_active and rule.is_applicable_for(self.object):
                    vc = get_handler(rule.handler)
                    if vc and bool(vc.SCOPE & scope):
                        r += [(vc, rule)]
        return r

    def _get_rules(self, model, id, scope, obj=None):
        ps = ValidationPolicySettings.objects.filter(
            model_id=model, object_id=str(id)).first()
        if not ps or not ps.policies:
            return []
        return [
            vc(self, obj, rule.config, scope, rule)
            for vc, rule in self._get_rule_settings(ps, scope)
        ]

    def get_rules(self):
        r = []
        # Object profile rules
        if self.object.object_profile:
            r += self._get_rules(
                "sa.ManagedObjectProfile",
                self.object.object_profile.id,
                BaseValidator.OBJECT,
                self.object,
            )
        # Object rules
        r += self._get_rules("sa.ManagedObject", self.object.id,
                             BaseValidator.OBJECT, self.object)
        # Interface rules
        profile_interfaces = defaultdict(list)
        for i in InvInterface.objects.filter(managed_object=self.object.id):
            if i.profile:
                profile_interfaces[i.profile] += [i]
            r += self._get_rules("inv.Interface", i.id,
                                 BaseValidator.INTERFACE, i)
        # Interface profile rules
        for p in profile_interfaces:
            ps = ValidationPolicySettings.objects.filter(
                model_id="inv.InterfaceProfile", object_id=str(p.id)).first()
            if not ps or not ps.policies:
                continue
            rs = self._get_rule_settings(ps, BaseValidator.INTERFACE)
            if rs:
                for iface in profile_interfaces[p]:
                    r += [
                        vc(self, iface, rule.config, BaseValidator.INTERFACE,
                           rule) for vc, rule in rs
                    ]
        # Subinterface profile rules
        profile_subinterfaces = defaultdict(list)
        for si in InvSubInterface.objects.filter(
                managed_object=self.object.id):
            p = si.get_profile()
            if p:
                profile_subinterfaces[p] += [si]
        for p in profile_subinterfaces:
            ps = ValidationPolicySettings.objects.filter(
                model_id="inv.InterfaceProfile", object_id=str(p.id)).first()
            if not ps or not ps.policies:
                continue
            rs = self._get_rule_settings(ps, BaseValidator.SUBINTERFACE)
            if rs:
                for si in profile_subinterfaces[p]:
                    r += [
                        vc(self, si, rule.config, BaseValidator.SUBINTERFACE,
                           rule) for vc, rule in rs
                    ]
        return r

    def get_fact_uuid(self, fact):
        r = [str(self.object.id), fact.cls
             ] + [str(getattr(fact, n)) for n in fact.ID]
        return uuid.uuid5(uuid.NAMESPACE_URL, "-".join(r))

    def get_fact_attrs(self, fact):
        return dict(fact.iter_factitems())

    def sync_facts(self):
        """
        Retrieve known facts and synchronize with database
        """
        self.logger.debug("Synchronizing facts")
        # Get facts from CLIPS
        self.logger.debug("Extracting facts")
        e_facts = {}  # uuid -> fact
        try:
            f = self.env.InitialFact()
        except clips.ClipsError:
            return  # No facts
        while f:
            if f.Template and f.Template.Name in self.templates:
                self.facts[f.Index] = f
                args = {}
                for k in f.Slots.keys():
                    v = f.Slots[k]
                    if v == clips.Nil:
                        v = None
                    args[str(k)] = v
                fi = self.fcls[f.Template.Name](**args)
                e_facts[self.get_fact_uuid(fi)] = fi
            f = f.Next()
        # Get facts from database
        now = datetime.datetime.now()
        collection = ObjectFact._get_collection()
        bulk = []
        new_facts = set(e_facts)
        for f in collection.find({"object": self.object.id}):
            if f["_id"] in e_facts:
                fact = e_facts[f["_id"]]
                f_attrs = self.get_fact_attrs(fact)
                if f_attrs != f["attrs"]:
                    # Changed facts
                    self.logger.debug("Fact %s has been changed: %s -> %s",
                                      f["_id"], f["attrs"], f_attrs)
                    bulk += [
                        UpdateOne(
                            {"_id": f["_id"]},
                            {
                                "$set": {
                                    "attrs": f_attrs,
                                    "changed": now,
                                    "label": smart_text(fact)
                                }
                            },
                        )
                    ]
                new_facts.remove(f["_id"])
            else:
                # Removed fact
                self.logger.debug("Fact %s has been removed", f["_id"])
                bulk += [DeleteOne({"_id": f["_id"]})]
        # New facts
        for f in new_facts:
            fact = e_facts[f]
            f_attrs = self.get_fact_attrs(fact)
            self.logger.debug("Creating fact %s: %s", f, f_attrs)
            bulk += [
                InsertOne({
                    "_id": f,
                    "object": self.object.id,
                    "cls": fact.cls,
                    "label": smart_text(fact),
                    "attrs": f_attrs,
                    "introduced": now,
                    "changed": now,
                })
            ]
        if bulk:
            self.logger.debug("Commiting changes to database")
            try:
                collection.bulk_write(bulk)
                self.logger.debug("Database has been synced")
            except BulkWriteError as e:
                self.logger.error("Bulk write error: '%s'", e.details)
                self.logger.error("Stopping check")
        else:
            self.logger.debug("Nothing changed")

    def compile_query(self, **kwargs):
        def wrap(x):
            for k in kwargs:
                if getattr(x, k, None) != kwargs[k]:
                    return False
            return True

        return wrap

    def find(self, **kwargs):
        """
        Search facts for match. Returns a list of matching facts
        """
        q = self.compile_query(**kwargs)
        return [f for f in six.itervalues(self.facts) if q(f)]

    def find_one(self, **kwargs):
        """
        Search for first matching fact. Returns fact or None
        """
        q = self.compile_query(**kwargs)
        for f in six.itervalues(self.facts):
            if q(f):
                return f
        return None

    def sync_alarms(self):
        """
        Raise/close related alarms
        """
        # Check whether any errors exist
        n_errors = sum(1 for e in self.iter_errors())
        alarm = ActiveAlarm.objects.filter(
            alarm_class=self.AC_POLICY_VIOLATION.id,
            managed_object=self.object.id).first()
        if n_errors:
            if not alarm:
                self.logger.info("Raise alarm")
                # Raise alarm
                alarm = ActiveAlarm(
                    timestamp=datetime.datetime.now(),
                    managed_object=self.object,
                    alarm_class=self.AC_POLICY_VIOLATION,
                    severity=2000,  # WARNING
                )
            # Log current error count on the new or existing alarm
            alarm.log_message("%d errors have been found" % n_errors)
        elif alarm:
            # Clear alarm
            self.logger.info("Clear alarm")
            alarm.clear_alarm("No errors have been registered")

    def setup_env(self, env):
        """
        Install additional CLIPS functions
        """
        logger.debug("Setting up CLIPS environment")
        self.env = env
        # Create wrappers
        logger.debug("Install function: match-re")
        env.BuildFunction("match-re", "?rx ?s",
                          "(return (python-call py-match-re ?rx ?s))")
Code example #11
def wipe(o):
    if not hasattr(o, "id"):
        try:
            o = ManagedObject.objects.get(id=o)
        except ManagedObject.DoesNotExist:
            return True
    log = PrefixLoggerAdapter(logger, str(o.id))
    # Wiping discovery tasks
    log.debug("Wiping discovery tasks")
    for j in [
            ManagedObject.BOX_DISCOVERY_JOB,
            ManagedObject.PERIODIC_DISCOVERY_JOB
    ]:
        Job.remove("discovery", j, key=o.id, pool=o.pool.name)
    # Wiping FM events
    log.debug("Wiping events")
    FailedEvent.objects.filter(managed_object=o.id).delete()
    ActiveEvent.objects.filter(managed_object=o.id).delete()
    ArchivedEvent.objects.filter(managed_object=o.id).delete()
    # Wiping alarms
    log.debug("Wiping alarms")
    for ac in (ActiveAlarm, ArchivedAlarm):
        for a in ac.objects.filter(managed_object=o.id):
            # Relink root causes
            my_root = a.root
            for iac in (ActiveAlarm, ArchivedAlarm):
                for ia in iac.objects.filter(root=a.id):
                    ia.root = my_root
                    ia.save()
            # Delete alarm
            a.delete()
    # Wiping MAC DB
    log.debug("Wiping MAC DB")
    MACDB._get_collection().remove({"managed_object": o.id})
    # Wiping discovery id cache
    log.debug("Wiping discovery id")
    DiscoveryID._get_collection().remove({"object": o.id})
    # Wiping interfaces, subs and links
    # Wipe links
    log.debug("Wiping links")
    for i in Interface.objects.filter(managed_object=o.id):
        # @todo: Remove aggregated links correctly
        Link.objects.filter(interfaces=i.id).delete()
    #
    log.debug("Wiping subinterfaces")
    SubInterface.objects.filter(managed_object=o.id).delete()
    log.debug("Wiping interfaces")
    Interface.objects.filter(managed_object=o.id).delete()
    log.debug("Wiping forwarding instances")
    ForwardingInstance.objects.filter(managed_object=o.id).delete()
    # Unbind from IPAM
    log.debug("Unbind from IPAM")
    for a in Address.objects.filter(managed_object=o):
        a.managed_object = None
        a.save()
    # Wipe object status
    log.debug("Wiping object status")
    ObjectStatus.objects.filter(object=o.id).delete()
    # Wipe outages
    log.debug("Wiping outages")
    Outage.objects.filter(object=o.id).delete()
    # Wipe uptimes
    log.debug("Wiping uptimes")
    Uptime.objects.filter(object=o.id).delete()
    # Wipe reboots
    log.debug("Wiping reboots")
    Reboot.objects.filter(object=o.id).delete()
    # Delete Managed Object's capabilities
    log.debug("Wiping capabilitites")
    ObjectCapabilities.objects.filter(object=o.id).delete()
    # Delete Managed Object's attributes
    log.debug("Wiping attributes")
    ManagedObjectAttribute.objects.filter(managed_object=o).delete()
    # Finally delete object and config
    log.debug("Finally wiping object")
    o.delete()
    log.debug("Done")
Code example #12
class ProfileChecker(object):
    base_logger = logging.getLogger("profilechecker")
    _rules_cache = cachetools.TTLCache(10, ttl=60)
    _re_cache = {}

    def __init__(
        self,
        address=None,
        pool=None,
        logger=None,
        snmp_community=None,
        calling_service="profilechecker",
        snmp_version=None,
    ):
        self.address = address
        self.pool = pool
        self.logger = PrefixLoggerAdapter(
            logger or self.base_logger, "%s][%s" % (self.pool or "", self.address or "")
        )
        self.result_cache = {}  # (method, param) -> result
        self.error = None
        self.snmp_community = snmp_community
        self.calling_service = calling_service
        self.snmp_version = snmp_version or [SNMP_v2c]
        self.ignoring_snmp = False
        # Check the raw argument: self.snmp_version can no longer be
        # None once the default above has been applied
        if snmp_version is None:
            self.logger.error("SNMP is not supported. Ignoring")
            self.ignoring_snmp = True
        if not self.snmp_community:
            self.logger.error("No SNMP credentials. Ignoring")
            self.ignoring_snmp = True

    def find_profile(self, method, param, result):
        """
        Find a profile matching the collected result
        :param method: Fingerprint collection method
        :param param: Method parameter
        :param result: Result of the check
        :return: Matched profile or None
        """
        r = defaultdict(list)
        d = self.get_rules()
        for k, value in sorted(six.iteritems(d), key=lambda x: x[0]):
            for v in value:
                r[v] += value[v]
        if (method, param) not in r:
            self.logger.warning("Not find rule for method: %s %s", method, param)
            return
        for match_method, value, action, profile, rname in r[(method, param)]:
            if self.is_match(result, match_method, value):
                self.logger.info("Matched profile: %s (%s)", profile, rname)
                # @todo: process MAYBE rule
                return profile

    def get_profile(self):
        """
        Returns profile for object, or None when not known
        """
        snmp_result = ""
        http_result = ""
        for ruleset in self.iter_rules():
            for (method, param), actions in ruleset:
                try:
                    result = self.do_check(method, param)
                    if not result:
                        continue
                    if "snmp" in method:
                        snmp_result = result
                    if "http" in method:
                        http_result = result
                    for match_method, value, action, profile, rname in actions:
                        if self.is_match(result, match_method, value):
                            self.logger.info("Matched profile: %s (%s)", profile, rname)
                            # @todo: process MAYBE rule
                            return profile
                except NOCError as e:
                    self.logger.error(e.message)
                    self.error = str(e.message)
                    return None
        if snmp_result or http_result:
            self.error = "Not find profile for OID: %s or HTTP string: %s" % (
                snmp_result,
                http_result,
            )
        elif not snmp_result:
            self.error = "Cannot fetch snmp data, check device for SNMP access"
        elif not http_result:
            self.error = "Cannot fetch HTTP data, check device for HTTP access"
        self.logger.info("Cannot detect profile: %s", self.error)
        return None

    def get_error(self):
        """
        Get error message
        :return:
        """
        return self.error

    @classmethod
    @cachetools.cachedmethod(operator.attrgetter("_rules_cache"), lock=lambda _: rules_lock)
    def get_profile_check_rules(cls):
        return list(ProfileCheckRule.objects.all().order_by("preference"))

    def get_rules(self):
        """
        Load ProfileCheckRules and return a dict, keyed by preference:
        {preference: {
            (method, param) -> [(
                    match_method,
                    value,
                    action,
                    profile,
                    rule_name
                ), ...]
        }}
        """
        self.logger.info('Compiling "Profile Check rules"')
        d = {}  # preference -> (method, param) -> [rule, ..]
        for r in self.get_profile_check_rules():
            if "snmp" in r.method and self.ignoring_snmp:
                continue
            if r.preference not in d:
                d[r.preference] = {}
            k = (r.method, r.param)
            if k not in d[r.preference]:
                d[r.preference][k] = []
            d[r.preference][k] += [(r.match_method, r.value, r.action, r.profile, r.name)]
        return d

    def iter_rules(self):
        d = self.get_rules()
        for p in sorted(d):
            yield list(six.iteritems(d[p]))

    @classmethod
    @cachetools.cachedmethod(operator.attrgetter("_re_cache"))
    def get_re(cls, regexp):
        return re.compile(regexp)

    def do_check(self, method, param):
        """
        Perform check
        """
        self.logger.debug("do_check(%s, %s)", method, param)
        if (method, param) in self.result_cache:
            self.logger.debug("Using cached value")
            return self.result_cache[method, param]
        h = getattr(self, "check_%s" % method, None)
        if not h:
            self.logger.error("Invalid check method '%s'. Ignoring", method)
            return None
        result = h(param)
        self.result_cache[method, param] = result
        return result

    def check_snmp_v2c_get(self, param):
        """
        Perform SNMP v2c GET. Param is OID or symbolic name
        """
        try:
            param = mib[param]
        except KeyError:
            self.logger.error("Cannot resolve OID '%s'. Ignoring", param)
            return None
        for v in self.snmp_version:
            if v == SNMP_v1:
                r = self.snmp_v1_get(param)
            elif v == SNMP_v2c:
                r = self.snmp_v2c_get(param)
            else:
                raise NOCError(msg="Unsupported SNMP version")
            if r:
                return r

    def check_http_get(self, param):
        """
        Perform HTTP GET check. Param can be URL path or :<port>/<path>
        """
        url = "http://%s%s" % (self.address, param)
        return self.http_get(url)

    def check_https_get(self, param):
        """
        Perform HTTPS GET check. Param can be URL path or :<port>/<path>
        """
        url = "https://%s%s" % (self.address, param)
        return self.https_get(url)

    def is_match(self, result, method, value):
        """
        Returns True when result matches value
        """
        if method == "eq":
            return result == value
        elif method == "contains":
            return value in result
        elif method == "re":
            return bool(self.get_re(value).search(result))
        else:
            self.logger.error("Invalid match method '%s'. Ignoring", method)
            return False

    def snmp_v1_get(self, param):
        """
        Perform SNMP v1 request. May be overridden for testing
        :param param:
        :return:
        """
        self.logger.info("SNMP v1 GET: %s", param)
        try:
            return open_sync_rpc(
                "activator", pool=self.pool, calling_service=self.calling_service
            ).snmp_v1_get(self.address, self.snmp_community, param)
        except RPCError as e:
            self.logger.error("RPC Error: %s", e)
            return None

    def snmp_v2c_get(self, param):
        """
        Perform SNMP v2c request. May be overridden for testing
        :param param:
        :return:
        """
        self.logger.info("SNMP v2c GET: %s", param)
        try:
            return open_sync_rpc(
                "activator", pool=self.pool, calling_service=self.calling_service
            ).snmp_v2c_get(self.address, self.snmp_community, param)
        except RPCError as e:
            self.logger.error("RPC Error: %s", e)
            return None

    def http_get(self, url):
        """
        Perform HTTP request. May be overridden for testing
        :param url: Request URL
        :return:
        """
        self.logger.info("HTTP Request: %s", url)
        try:
            return open_sync_rpc(
                "activator", pool=self.pool, calling_service=self.calling_service
            ).http_get(url, True)
        except RPCError as e:
            self.logger.error("RPC Error: %s", e)
            return None

    def https_get(self, url):
        """
        Perform HTTPS request. May be overridden for testing
        :param url: Request URL
        :return:
        """
        return self.http_get(url)
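A hedged usage sketch of the checker above; the address, pool name, and community string are placeholders:

checker = ProfileChecker(
    address="192.0.2.1",      # placeholder device address
    pool="default",           # assumed activator pool name
    snmp_community="public",  # placeholder community
)
profile = checker.get_profile()
if profile is None:
    print(checker.get_error())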
Code example #13
class HTTP(object):
    CONNECT_TIMEOUT = config.http_client.connect_timeout
    REQUEST_TIMEOUT = config.http_client.request_timeout

    class HTTPError(NOCError):
        default_code = ERR_HTTP_UNKNOWN

    def __init__(self, script):
        self.script = script
        self.logger = PrefixLoggerAdapter(script.logger, "http")

    def get_url(self, path):
        address = self.script.credentials["address"]
        port = self.script.credentials.get("http_port")
        if port:
            address += ":%s" % port
        proto = self.script.credentials.get("http_protocol", "http")
        return "%s://%s%s" % (proto, address, path)

    def get(self, path, headers=None, cached=False, json=False, eof_mark=None):
        """
        Perform HTTP GET request
        :param path: URI
        :param headers: Dict of additional headers
        :param cached: Cache result
        :param json: Decode json if set to True
        :param eof_mark: Wait for eof_mark in the stream to end the session (some devices return content length 0)
        """
        self.logger.debug("GET %s", path)
        if cached:
            cache_key = "get_%s" % path
            r = self.script.root.http_cache.get(cache_key)
            if r is not None:
                self.logger.debug("Use cached result")
                return r
        code, headers, result = fetch_sync(self.get_url(path),
                                           headers=headers,
                                           request_timeout=60,
                                           follow_redirects=True,
                                           allow_proxy=False,
                                           validate_cert=False,
                                           eof_mark=eof_mark)
        # pylint: disable=superfluous-parens
        if not (200 <= code <= 299):  # noqa
            raise self.HTTPError(msg="HTTP Error (%s)" % result[:256],
                                 code=code)
        if json:
            try:
                result = ujson.loads(result)
            except ValueError as e:
                raise self.HTTPError("Failed to decode JSON: %s", e)
        self.logger.debug("Result: %r", result)
        if cached:
            self.script.root.http_cache[cache_key] = result
        return result

    def post(self,
             path,
             data,
             headers=None,
             cached=False,
             json=False,
             eof_mark=None):
        """
        Perform HTTP POST request
        :param path: URI
        :param data: Request body
        :param headers: Dict of additional headers
        :param cached: Cache result
        :param json: Decode json if set to True
        :param eof_mark: Wait for eof_mark in the stream to end the session (some devices return content length 0)
        """
        self.logger.debug("POST %s %s", path, data)
        if cached:
            cache_key = "post_%s" % path
            r = self.script.root.http_cache.get(cache_key)
            if r is not None:
                self.logger.debug("Use cached result")
                return r
        code, headers, result = fetch_sync(self.get_url(path),
                                           method="POST",
                                           headers=headers,
                                           request_timeout=60,
                                           follow_redirects=True,
                                           allow_proxy=False,
                                           validate_cert=False,
                                           eof_mark=eof_mark)
        # pylint: disable=superfluous-parens
        if not (200 <= code <= 299):  # noqa
            raise self.HTTPError(msg="HTTP Error (%s)" % result[:256],
                                 code=code)
        if json:
            try:
                result = ujson.loads(result)
            except ValueError as e:
                raise self.HTTPError(msg="Failed to decode JSON: %s" % e)
        self.logger.debug("Result: %r", result)
        if cached:
            self.script.root.http_cache[cache_key] = result
        return result

    def close(self):
        pass
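From inside a script the wrapper above is typically reachable as an attribute (self.http is assumed here); paths and payloads are purely illustrative:

# Hypothetical calls from a script method; self.http is assumed to be
# an HTTP instance bound to this script's credentials
status = self.http.get("/api/v1/status", json=True, cached=True)
ack = self.http.post("/api/v1/reboot", data="{}", json=True)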
Code example #14
File: base.py Project: skripkar/noc
class MMLBase(object):
    name = "mml"
    iostream_class = None
    default_port = None
    BUFFER_SIZE = config.activator.buffer_size
    MATCH_TAIL = 256
    # Retries on immediate disconnect
    CONNECT_RETRIES = config.activator.connect_retries
    # Timeout after immediate disconnect
    CONNECT_TIMEOUT = config.activator.connect_timeout
    # compiled capabilities
    HAS_TCP_KEEPALIVE = hasattr(socket, "SO_KEEPALIVE")
    HAS_TCP_KEEPIDLE = hasattr(socket, "TCP_KEEPIDLE")
    HAS_TCP_KEEPINTVL = hasattr(socket, "TCP_KEEPINTVL")
    HAS_TCP_KEEPCNT = hasattr(socket, "TCP_KEEPCNT")
    HAS_TCP_NODELAY = hasattr(socket, "TCP_NODELAY")
    # Time until sending first keepalive probe
    KEEP_IDLE = 10
    # Keepalive packets interval
    KEEP_INTVL = 10
    # Terminate connection after N keepalive failures
    KEEP_CNT = 3

    def __init__(self, script, tos=None):
        self.script = script
        self.profile = script.profile
        self.logger = PrefixLoggerAdapter(self.script.logger, self.name)
        self.iostream = None
        self.ioloop = None
        self.command = None
        self.buffer = ""
        self.is_started = False
        self.result = None
        self.error = None
        self.is_closed = False
        self.close_timeout = None
        self.current_timeout = None
        self.tos = tos
        self.rx_mml_end = re.compile(self.script.profile.pattern_mml_end, re.MULTILINE)
        if self.script.profile.pattern_mml_continue:
            self.rx_mml_continue = re.compile(self.script.profile.pattern_mml_continue, re.MULTILINE)
        else:
            self.rx_mml_continue = None

    def close(self):
        self.script.close_current_session()
        self.close_iostream()
        if self.ioloop:
            self.logger.debug("Closing IOLoop")
            self.ioloop.close(all_fds=True)
            self.ioloop = None
        self.is_closed = True

    def close_iostream(self):
        if self.iostream:
            self.iostream.close()

    def deferred_close(self, session_timeout):
        if self.is_closed or not self.iostream:
            return
        self.logger.debug("Setting close timeout to %ss",
                          session_timeout)
        # Cannot call call_later directly due to
        # thread-safety problems
        # See tornado issue #1773
        tornado.ioloop.IOLoop.instance().add_callback(
            self._set_close_timeout,
            session_timeout
        )

    def _set_close_timeout(self, session_timeout):
        """
        Wrapper to deal with IOLoop.add_timeout thread safety problem
        :param session_timeout:
        :return:
        """
        self.close_timeout = tornado.ioloop.IOLoop.instance().call_later(
            session_timeout,
            self.close
        )

    def create_iostream(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if self.tos:
            s.setsockopt(
                socket.IPPROTO_IP, socket.IP_TOS, self.tos
            )
        if self.HAS_TCP_NODELAY:
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if self.HAS_TCP_KEEPALIVE:
            s.setsockopt(
                socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1
            )
            if self.HAS_TCP_KEEPIDLE:
                s.setsockopt(socket.SOL_TCP,
                             socket.TCP_KEEPIDLE, self.KEEP_IDLE)
            if self.HAS_TCP_KEEPINTVL:
                s.setsockopt(socket.SOL_TCP,
                             socket.TCP_KEEPINTVL, self.KEEP_INTVL)
            if self.HAS_TCP_KEEPCNT:
                s.setsockopt(socket.SOL_TCP,
                             socket.TCP_KEEPCNT, self.KEEP_CNT)
        return self.iostream_class(s, self)

    def set_timeout(self, timeout):
        if timeout:
            self.logger.debug("Setting timeout: %ss", timeout)
            self.current_timeout = datetime.timedelta(seconds=timeout)
        else:
            if self.current_timeout:
                self.logger.debug("Resetting timeouts")
            self.current_timeout = None

    def set_script(self, script):
        self.script = script
        if self.close_timeout:
            tornado.ioloop.IOLoop.instance().remove_timeout(self.close_timeout)
            self.close_timeout = None

    @tornado.gen.coroutine
    def send(self, cmd):
        # @todo: Apply encoding
        cmd = str(cmd)
        self.logger.debug("Send: %r", cmd)
        yield self.iostream.write(cmd)

    @tornado.gen.coroutine
    def submit(self):
        # Create iostream and connect, when necessary
        if not self.iostream:
            self.iostream = self.create_iostream()
            address = (
                self.script.credentials.get("address"),
                self.script.credentials.get("cli_port", self.default_port)
            )
            self.logger.debug("Connecting %s", address)
            try:
                yield self.iostream.connect(address)
            except tornado.iostream.StreamClosedError:
                self.logger.debug("Connection refused")
                self.error = MMLConnectionRefused("Connection refused")
                raise tornado.gen.Return(None)
            self.logger.debug("Connected")
            yield self.iostream.startup()
        # Perform all necessary login procedures
        if not self.is_started:
            self.is_started = True
            yield self.send(self.profile.get_mml_login(self.script))
            yield self.get_mml_response()
            if self.error:
                self.error = MMLAuthFailed(str(self.error))
                raise tornado.gen.Return(None)
        # Send command
        yield self.send(self.command)
        r = yield self.get_mml_response()
        raise tornado.gen.Return(r)

    @tornado.gen.coroutine
    def get_mml_response(self):
        result = []
        header_sep = self.profile.mml_header_separator
        while True:
            r = yield self.read_until_end()
            r = r.strip()
            # Process header
            if header_sep not in r:
                self.result = ""
                self.error = MMLBadResponse("Missing header separator")
                raise tornado.gen.Return(None)
            header, r = r.split(header_sep, 1)
            code, msg = self.profile.parse_mml_header(header)
            if code:
                # MML Error
                self.result = ""
                self.error = MMLError("%s (code=%s)" % (msg, code))
                raise tornado.gen.Return(None)
            # Process continuation
            if self.rx_mml_continue:
                # Process continued block
                offset = max(0, len(r) - self.MATCH_TAIL)
                match = self.rx_mml_continue.search(r, offset)
                if match:
                    self.logger.debug("Continuing in the next block")
                    result += [r[:match.start()]]
                    continue
            result += [r]
            break
        self.result = "".join(result)
        raise tornado.gen.Return(self.result)

    def execute(self, cmd, **kwargs):
        """
        Perform command and return result
        :param cmd:
        :param kwargs:
        :return:
        """
        if self.close_timeout:
            self.logger.debug("Removing close timeout")
            self.ioloop.remove_timeout(self.close_timeout)
            self.close_timeout = None
        self.buffer = ""
        self.command = self.profile.get_mml_command(cmd, **kwargs)
        self.error = None
        if not self.ioloop:
            self.logger.debug("Creating IOLoop")
            self.ioloop = tornado.ioloop.IOLoop()
        with Span(server=self.script.credentials.get("address"),
                  service=self.name, in_label=self.command) as s:
            self.ioloop.run_sync(self.submit)
            if self.error:
                if s:
                    s.error_text = str(self.error)
                raise self.error
            else:
                return self.result

    @tornado.gen.coroutine
    def read_until_end(self):
        connect_retries = self.CONNECT_RETRIES
        while True:
            try:
                f = self.iostream.read_bytes(self.BUFFER_SIZE,
                                             partial=True)
                if self.current_timeout:
                    r = yield tornado.gen.with_timeout(
                        self.current_timeout,
                        f
                    )
                else:
                    r = yield f
            except tornado.iostream.StreamClosedError:
                # Check if remote end closes connection just
                # after connection established
                if not self.is_started and connect_retries:
                    self.logger.info(
                        "Connection reset. %d retries left. Waiting %d seconds",
                        connect_retries, self.CONNECT_TIMEOUT
                    )
                    while connect_retries:
                        yield tornado.gen.sleep(self.CONNECT_TIMEOUT)
                        connect_retries -= 1
                        self.iostream = self.create_iostream()
                        address = (
                            self.script.credentials.get("address"),
                            self.script.credentials.get("cli_port", self.default_port)
                        )
                        self.logger.debug("Connecting %s", address)
                        try:
                            yield self.iostream.connect(address)
                            break
                        except tornado.iostream.StreamClosedError:
                            if not connect_retries:
                                raise tornado.iostream.StreamClosedError()
                    continue
                else:
                    raise tornado.iostream.StreamClosedError()
            except tornado.gen.TimeoutError:
                self.logger.info("Timeout error")
                raise tornado.gen.TimeoutError("Timeout")
            self.logger.debug("Received: %r", r)
            self.buffer += r
            offset = max(0, len(self.buffer) - self.MATCH_TAIL)
            match = self.rx_mml_end.search(self.buffer, offset)
            if match:
                self.logger.debug("End of the block")
                r = self.buffer[:match.start()]
                self.buffer = self.buffer[match.end():]
                raise tornado.gen.Return(r)

    def shutdown_session(self):
        if self.profile.shutdown_session:
            self.logger.debug("Shutdown session")
            self.profile.shutdown_session(self.script)
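The socket tuning performed in create_iostream() can be reproduced standalone; a minimal sketch, assuming a Linux host where the SOL_TCP and TCP_KEEP* constants exist:

import socket

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)  # disable Nagle
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)  # enable keepalive
if hasattr(socket, "TCP_KEEPIDLE"):
    s.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, 10)   # KEEP_IDLE
if hasattr(socket, "TCP_KEEPINTVL"):
    s.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, 10)  # KEEP_INTVL
if hasattr(socket, "TCP_KEEPCNT"):
    s.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, 3)     # KEEP_CNT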