示例#1
0
文件: main.py 项目: stschwar/scion
class CertServer(SCIONElement):
    """
    The SCION Certificate Server.
    """
    SERVICE_TYPE = CERTIFICATE_SERVICE
    # ZK path for incoming cert chains
    ZK_CC_CACHE_PATH = "cert_chain_cache"
    # ZK path for incoming TRCs
    ZK_TRC_CACHE_PATH = "trc_cache"
    ZK_DRKEY_PATH = "drkey_cache"

    def __init__(self,
                 server_id,
                 conf_dir,
                 spki_cache_dir=GEN_CACHE_PATH,
                 prom_export=None):
        """
        Set up the request handlers, the control-payload dispatch table, the
        ZooKeeper connection with its shared caches, and the key material.

        :param str server_id: server identifier.
        :param str conf_dir: configuration directory.
        :param spki_cache_dir: forwarded unchanged to the base class
            constructor.
        :param str prom_export: prometheus export address.
        """
        super().__init__(server_id,
                         conf_dir,
                         spki_cache_dir=spki_cache_dir,
                         prom_export=prom_export)
        self.config = self._load_as_conf()
        # Per-request-type metric labels; None when self._labels is falsy.
        cc_labels = {**self._labels, "type": "cc"} if self._labels else None
        trc_labels = {**self._labels, "type": "trc"} if self._labels else None
        drkey_labels = {
            **self._labels, "type": "drkey"
        } if self._labels else None
        # One request handler per object type, each wired with its
        # check / fetch / reply callbacks.
        self.cc_requests = RequestHandler.start(
            "CC Requests",
            self._check_cc,
            self._fetch_cc,
            self._reply_cc,
            labels=cc_labels,
        )
        self.trc_requests = RequestHandler.start(
            "TRC Requests",
            self._check_trc,
            self._fetch_trc,
            self._reply_trc,
            labels=trc_labels,
        )
        self.drkey_protocol_requests = RequestHandler.start(
            "DRKey Requests",
            self._check_drkey,
            self._fetch_drkey,
            self._reply_proto_drkey,
            labels=drkey_labels,
        )

        # Maps payload class and message type to the method that handles it.
        self.CTRL_PLD_CLASS_MAP = {
            PayloadClass.CERT: {
                CertMgmtType.CERT_CHAIN_REQ: self.process_cert_chain_request,
                CertMgmtType.CERT_CHAIN_REPLY: self.process_cert_chain_reply,
                CertMgmtType.TRC_REQ: self.process_trc_request,
                CertMgmtType.TRC_REPLY: self.process_trc_reply,
            },
            PayloadClass.DRKEY: {
                DRKeyMgmtType.FIRST_ORDER_REQUEST: self.process_drkey_request,
                DRKeyMgmtType.FIRST_ORDER_REPLY: self.process_drkey_reply,
            },
        }

        # ZooKeeper party membership and the three shared caches; each cache
        # is given the handler that consumes its entries.
        zkid = ZkID.from_values(self.addr.isd_as, self.id,
                                [(self.addr.host, self._port)]).pack()
        self.zk = Zookeeper(self.topology.isd_as, CERTIFICATE_SERVICE, zkid,
                            self.topology.zookeepers)
        self.zk.retry("Joining party", self.zk.party_setup)
        self.trc_cache = ZkSharedCache(self.zk, self.ZK_TRC_CACHE_PATH,
                                       self._cached_trcs_handler)
        self.cc_cache = ZkSharedCache(self.zk, self.ZK_CC_CACHE_PATH,
                                      self._cached_certs_handler)
        self.drkey_cache = ZkSharedCache(self.zk, self.ZK_DRKEY_PATH,
                                         self._cached_drkeys_handler)
        # Keys loaded from the configuration directory.
        self.signing_key = get_sig_key(self.conf_dir)
        self.private_key = get_enc_key(self.conf_dir)
        # Local stores for DRKey secret values and first order DRKeys.
        self.drkey_secrets = ExpiringDict(DRKEY_MAX_SV, DRKEY_MAX_TTL)
        self.first_order_drkeys = ExpiringDict(DRKEY_MAX_KEYS, DRKEY_MAX_TTL)

    def worker(self):
        """
        Background loop: once per cycle, refresh the IS_MASTER metric, process
        the entries of the shared ZK caches, and try to take the master lock.
        While the lock is held, old entries of the shared caches are expired.
        """
        cycle = 1.0
        max_entry_age = cycle * 10
        cycle_start = SCIONTime.get_time()
        while self.run_flag.is_set():
            quiet = self._quiet_startup()
            sleep_interval(cycle_start, cycle, "CS.worker cycle", quiet)
            cycle_start = SCIONTime.get_time()
            if self._labels:
                # Keep the IS_MASTER gauge in sync with the lock state.
                is_master = int(self.zk.have_lock())
                IS_MASTER.labels(**self._labels).set(is_master)
            try:
                self.zk.wait_connected()
                self.trc_cache.process()
                self.cc_cache.process()
                self.drkey_cache.process()
                # Non-blocking attempt to become (or stay) master.
                lock_state = self.zk.get_lock(lock_timeout=0, conn_timeout=0)
                if not lock_state:
                    continue
                if lock_state == ZK_LOCK_SUCCESS:
                    logging.info("Became master")
                self.trc_cache.expire(max_entry_age)
                self.cc_cache.expire(max_entry_age)
                self.drkey_cache.expire(max_entry_age)
            except ZkNoConnection:
                logging.warning('worker(): ZkNoConnection')

    def _cached_trcs_handler(self, raw_entries):
        """
        Feed TRCs read from the shared ZK cache into process_trc_reply.

        :param raw_entries: list of raw (UTF-8 encoded) TRCs.
        """
        for entry in raw_entries:
            parsed = TRC.from_raw(entry.decode('utf-8'))
            reply = TRCReply.from_values(parsed)
            self.process_trc_reply(CtrlPayload(CertMgmt(reply)), None,
                                   from_zk=True)
        count = len(raw_entries)
        if count > 0:
            logging.debug("Processed %s trcs from ZK", count)

    def _cached_certs_handler(self, raw_entries):
        """
        Feed certificate chains read from the shared ZK cache into
        process_cert_chain_reply.

        :param raw_entries: list of raw (UTF-8 encoded) certificate chains.
        """
        for entry in raw_entries:
            parsed = CertificateChain.from_raw(entry.decode('utf-8'))
            reply = CertChainReply.from_values(parsed)
            self.process_cert_chain_reply(CtrlPayload(CertMgmt(reply)), None,
                                          from_zk=True)
        count = len(raw_entries)
        if count > 0:
            logging.debug("Processed %s certs from ZK", count)

    def _cached_drkeys_handler(self, raw_entries):
        """
        Feed DRKey replies read from the shared ZK cache into
        process_drkey_reply.

        :param raw_entries: iterable of raw DRKey replies.
        """
        for entry in raw_entries:
            reply = DRKeyReply.from_raw(entry)
            wrapped = CtrlPayload(DRKeyMgmt(reply))
            self.process_drkey_reply(wrapped, None, from_zk=True)

    def _share_object(self, pld, is_trc):
        """
        Store a TRC or certificate chain in the matching shared ZK cache.

        The entry name is "<hash of the packed object>-<current time>".
        Losing the ZK connection is logged and otherwise ignored.

        :param pld: the object to share; must provide pack().
        :param bool is_trc: True to store in the TRC cache, False to store in
            the certificate chain cache.
        """
        # Resolve the label up front. The original applied "%" before the
        # conditional expression, so the cert-chain warning was just "CC".
        kind = "TRC" if is_trc else "CC"
        cache = self.trc_cache if is_trc else self.cc_cache
        pld_packed = pld.pack()
        pld_hash = crypto_hash(pld_packed).hex()
        entry_name = "%s-%s" % (pld_hash, SCIONTime.get_time())
        try:
            cache.store(entry_name, pld_packed)
        except ZkNoConnection:
            logging.warning("Unable to store %s in shared path: "
                            "no connection to ZK", kind)
            return
        logging.debug("%s stored in ZK: %s", kind, pld_hash)

    def process_cert_chain_request(self, cpld, meta):
        """
        Process a certificate chain request.

        If the chain is in the trust store it is sent back immediately. If it
        is missing, the request is queued when it comes from the local AS and
        dropped (with a warning) otherwise.

        :param cpld: control payload wrapping a CertChainRequest.
        :param meta: metadata of the requester.
        """
        cmgt = cpld.union
        req = cmgt.union
        assert isinstance(req, CertChainRequest), type(req)
        key = req.isd_as(), req.p.version
        logging.info("Cert chain request received for %sv%s from %s", *key,
                     meta)
        # self._labels may be falsy (see __init__ and worker); unpacking None
        # with ** would raise TypeError.
        if self._labels:
            REQS_TOTAL.labels(**self._labels, type="cc").inc()
        local = meta.ia == self.addr.isd_as
        req_info = (meta, req, cpld.req_id)
        if self._check_cc(key):
            self._reply_cc(key, req_info)
        elif local:
            self.cc_requests.put((key, req_info))
        else:
            logging.warning(
                "Dropping CC request from %s for %sv%s: "
                "CC not found && requester is not local)", meta, *key)

    def process_cert_chain_reply(self, cpld, meta, from_zk=False):
        """
        Process a certificate chain reply: add the chain to the trust store,
        share it via ZK unless it came from ZK, and notify the cert chain
        request handler.

        :param cpld: control payload wrapping a CertChainReply.
        :param meta: metadata of the sender (not used here).
        :param bool from_zk: True if the reply was read from ZK.
        """
        reply = cpld.union.union
        assert isinstance(reply, CertChainReply), type(reply)
        chain = reply.chain
        ia_ver = chain.get_leaf_isd_as_ver()
        isd_as, version = ia_ver[0], ia_ver[1]
        logging.info("Cert chain reply received for %sv%s (ZK: %s)" %
                     (isd_as, version, from_zk))
        self.trust_store.add_cert(chain)
        if not from_zk:
            self._share_object(chain, is_trc=False)
        # Wake up every request waiting for this certificate chain.
        self.cc_requests.put((ia_ver, None))

    def _check_cc(self, key):
        """
        Tell whether the trust store holds the requested certificate chain.

        :param key: (isd_as, version) pair; NEWEST_VERSION is looked up with
            version None.
        :returns: True if a chain was found, False otherwise.
        """
        isd_as, requested_ver = key
        if requested_ver == CertChainRequest.NEWEST_VERSION:
            requested_ver = None
        if not self.trust_store.get_cert(isd_as, requested_ver):
            logging.debug('Cert chain not found for %sv%s', *key)
            return False
        return True

    def _fetch_cc(self, key, req_info):
        """
        Request a missing certificate chain, unless the original request has
        its cacheOnly flag set.

        :param key: (isd_as, version) pair of the missing chain.
        :param req_info: (meta, request, req_id) triple of the original
            request.
        """
        _meta, original_request, _req_id = req_info
        if not original_request.p.cacheOnly:
            self._send_cc_request(*key)

    def _send_cc_request(self, isd_as, ver):
        """
        Send a cache-only certificate chain request to the CS service address
        of the given ISD-AS, if a path to it is available.

        :param isd_as: ISD-AS whose certificate chain is requested.
        :param ver: requested chain version.
        """
        request = CertChainRequest.from_values(isd_as, ver, cache_only=True)
        path_meta = self._get_path_via_sciond(isd_as)
        if not path_meta:
            logging.warning(
                "Cert chain request (for %s) not sent: "
                "no path found", request.short_desc())
            return
        forward_path = path_meta.fwd_path()
        meta = self._build_meta(isd_as, host=SVCType.CS_A, path=forward_path)
        req_id = mk_ctrl_req_id()
        payload = CtrlPayload(CertMgmt(request), req_id=req_id)
        self.send_meta(payload, meta)
        logging.info(
            "Cert chain request sent to %s via [%s]: %s [id: %016x]", meta,
            path_meta.short_desc(), request.short_desc(), req_id)

    def _reply_cc(self, key, req_info):
        """
        Send the certificate chain identified by key to the requester.

        :param key: (isd_as, version) pair; NEWEST_VERSION is looked up with
            version None.
        :param req_info: indexable whose item 0 is the requester metadata and
            item 2 the request id.
        """
        isd_as, ver = key
        if ver == CertChainRequest.NEWEST_VERSION:
            ver = None
        meta, req_id = req_info[0], req_info[2]
        chain = self.trust_store.get_cert(isd_as, ver)
        reply = CertMgmt(CertChainReply.from_values(chain))
        self.send_meta(CtrlPayload(reply, req_id=req_id), meta)
        logging.info("Cert chain for %sv%s sent to %s [id: %016x]", isd_as,
                     ver, meta, req_id)

    def process_trc_request(self, cpld, meta):
        """
        Process a TRC request.

        If the TRC is in the trust store it is sent back immediately. If it is
        missing, the request is queued when it comes from the local AS and
        dropped (with a warning) otherwise.

        :param cpld: control payload wrapping a TRCRequest.
        :param meta: metadata of the requester.
        """
        cmgt = cpld.union
        req = cmgt.union
        assert isinstance(req, TRCRequest), type(req)
        key = req.isd_as()[0], req.p.version
        logging.info("TRC request received for %sv%s from %s [id: %s]", *key,
                     meta, cpld.req_id_str())
        # self._labels may be falsy (see __init__ and worker); unpacking None
        # with ** would raise TypeError.
        if self._labels:
            REQS_TOTAL.labels(**self._labels, type="trc").inc()
        local = meta.ia == self.addr.isd_as
        req_info = (meta, req, cpld.req_id)
        if self._check_trc(key):
            self._reply_trc(key, req_info)
        elif local:
            self.trc_requests.put((key, req_info))
        else:
            logging.warning(
                "Dropping TRC request from %s for %sv%s: "
                "TRC not found && requester is not local)", meta, *key)

    def process_trc_reply(self, cpld, meta, from_zk=False):
        """
        Process a TRC reply: add the TRC to the trust store, share it via ZK
        unless it came from ZK, and notify the TRC request handler.

        :param cpld: control payload wrapping a TRCReply.
        :param meta: metadata of the sender (not used here).
        :param bool from_zk: True if the reply was read from ZK.
        """
        reply = cpld.union.union
        assert isinstance(reply, TRCReply), type(reply)
        trc = reply.trc
        isd, ver = trc.get_isd_ver()
        logging.info("TRCReply received for ISD %sv%s, ZK: %s [id: %s]", isd,
                     ver, from_zk, cpld.req_id_str())
        self.trust_store.add_trc(trc)
        if not from_zk:
            self._share_object(trc, is_trc=True)
        # Wake up every request waiting for this TRC.
        self.trc_requests.put(((isd, ver), None))

    def _check_trc(self, key):
        """
        Tell whether the trust store holds the requested TRC.

        :param key: (isd, version) pair; NEWEST_VERSION is looked up with
            version None.
        :returns: True if a TRC was found, False otherwise.
        """
        isd, requested_ver = key
        if requested_ver == TRCRequest.NEWEST_VERSION:
            requested_ver = None
        if not self.trust_store.get_trc(isd, requested_ver):
            logging.debug('TRC not found for %sv%s', *key)
            return False
        return True

    def _fetch_trc(self, key, req_info):
        """
        Request a missing TRC, unless the original request has its cacheOnly
        flag set.

        :param key: (isd, version) pair of the missing TRC.
        :param req_info: (meta, request, req_id) triple of the original
            request.
        """
        _meta, original_request, _req_id = req_info
        if not original_request.p.cacheOnly:
            self._send_trc_request(*key)

    def _send_trc_request(self, isd, ver):
        """
        Send a cache-only TRC request to the CS service address at the
        destination of a path obtained via sciond, if such a path exists.

        :param isd: ISD whose TRC is requested.
        :param ver: requested TRC version.
        """
        request = TRCRequest.from_values(isd, ver, cache_only=True)
        path_meta = self._get_path_via_sciond(request.isd_as())
        if not path_meta:
            logging.warning("TRC request not sent for %s: no path found.",
                            request.short_desc())
            return
        meta = self._build_meta(path_meta.dst_ia(),
                                host=SVCType.CS_A,
                                path=path_meta.fwd_path())
        req_id = mk_ctrl_req_id()
        payload = CtrlPayload(CertMgmt(request), req_id=req_id)
        self.send_meta(payload, meta)
        logging.info("TRC request sent to %s via [%s]: %s [id: %016x]",
                     meta, path_meta.short_desc(), request.short_desc(),
                     req_id)

    def _reply_trc(self, key, req_info):
        """
        Send the TRC identified by key to the requester.

        :param key: (isd, version) pair; NEWEST_VERSION is looked up with
            version None.
        :param req_info: indexable whose item 0 is the requester metadata and
            item 2 the request id.
        """
        isd, ver = key
        if ver == TRCRequest.NEWEST_VERSION:
            ver = None
        meta, req_id = req_info[0], req_info[2]
        stored_trc = self.trust_store.get_trc(isd, ver)
        reply = CertMgmt(TRCReply.from_values(stored_trc))
        self.send_meta(CtrlPayload(reply, req_id=req_id), meta)
        logging.info("TRC for %sv%s sent to %s [id: %016x]", isd, ver, meta,
                     req_id)

    def process_drkey_request(self, cpld, meta):
        """
        Process first order DRKey requests from other ASes.

        The request is verified first; invalid requests are logged and
        dropped. For a valid request a DRKeyReply is built and sent back with
        the request id of the incoming payload.

        :param cpld: control payload wrapping a DRKeyRequest.
        :param UDPMetadata meta: the metadata
        """
        dpld = cpld.union
        req = dpld.union
        assert isinstance(req, DRKeyRequest), type(req)
        logging.info("DRKeyRequest received from %s: %s [id: %s]", meta,
                     req.short_desc(), cpld.req_id_str())
        # self._labels may be falsy (see __init__ and worker); unpacking None
        # with ** would raise TypeError.
        if self._labels:
            REQS_TOTAL.labels(**self._labels, type="drkey").inc()
        try:
            cert = self._verify_drkey_request(req, meta)
        except SCIONVerificationError as e:
            logging.warning("Invalid DRKeyRequest from %s. Reason %s: %s",
                            meta, e, req.short_desc())
            return
        # Same guard as in _fetch_drkey: without our own chain and TRC the
        # version lookups below would fail on None.
        own_chain = self.trust_store.get_cert(self.addr.isd_as)
        own_trc = self.trust_store.get_trc(self.addr.isd_as[0])
        if not own_chain or not own_trc:
            logging.warning(
                "DRKeyReply for %s not sent. Own CertChain/TRC not present.",
                meta.ia)
            return
        sv = self._get_drkey_secret(get_drkey_exp_time(req.p.flags.prefetch))
        rep = get_drkey_reply(sv, self.addr.isd_as, meta.ia, self.private_key,
                              self.signing_key, own_chain.certs[0].version,
                              cert, own_trc.version)
        self.send_meta(CtrlPayload(DRKeyMgmt(rep), req_id=cpld.req_id), meta)
        logging.info("DRKeyReply sent to %s: %s [id: %s]", meta,
                     req.short_desc(), cpld.req_id_str())

    def _verify_drkey_request(self, req, meta):
        """
        Check a first order DRKey request before answering it.

        The request must target this ISD-AS, must not be older than
        DRKEY_REQUEST_TIMEOUT, and its signature must verify against the
        requester's certificate chain and TRC. If the chain or TRC is not in
        the trust store, it is requested and verification fails.

        :param DRKeyRequest req: the first order DRKey request.
        :param UDPMetadata meta: the metadata.
        :returns: Certificate of the requester.
        :rtype: Certificate
        :raises: SCIONVerificationError
        """
        if self.addr.isd_as != req.isd_as:
            raise SCIONVerificationError("Request for other ISD-AS: %s" %
                                         req.isd_as)
        if drkey_time() - req.p.timestamp > DRKEY_REQUEST_TIMEOUT:
            raise SCIONVerificationError(
                "Expired request from %s. %ss old. Max %ss" %
                (meta.ia, drkey_time() - req.p.timestamp,
                 DRKEY_REQUEST_TIMEOUT))
        requester_trc = self.trust_store.get_trc(meta.ia[0])
        requester_chain = self.trust_store.get_cert(meta.ia, req.p.certVer)
        # Collect every missing piece so that both get requested at once.
        missing = []
        if not requester_chain:
            self._send_cc_request(meta.ia, req.p.certVer)
            missing.append("Certificate not present for %s(v: %s)" %
                           (meta.ia, req.p.certVer))
        if not requester_trc:
            self._send_trc_request(meta.ia[0], req.p.trcVer)
            missing.append("TRC not present for %s(v: %s)" %
                           (meta.ia[0], req.p.trcVer))
        if missing:
            raise SCIONVerificationError(", ".join(missing))
        signed_input = drkey_signing_input_req(req.isd_as,
                                               req.p.flags.prefetch,
                                               req.p.timestamp)
        try:
            verify_sig_chain_trc(signed_input, req.p.signature, meta.ia,
                                 requester_chain, requester_trc)
        except SCIONVerificationError as e:
            raise SCIONVerificationError(str(e))
        return requester_chain.certs[0]

    def process_drkey_reply(self, cpld, meta, from_zk=False):
        """
        Process first order DRKey reply from other ASes.

        The reply is verified and decrypted; on failure it is logged and
        dropped. Otherwise the key is stored locally, shared via ZK unless it
        came from ZK, and the DRKey request handler is notified.

        :param cpld: control payload wrapping a DRKeyReply.
        :param UDPMetadata meta: the metadata (None for replies read from ZK)
        :param Bool from_zk: if the reply has been received from Zookeeper
        """
        dpld = cpld.union
        rep = dpld.union
        assert isinstance(rep, DRKeyReply), type(rep)
        logging.info("DRKeyReply received from %s: %s [id: %s]", meta,
                     rep.short_desc(), cpld.req_id_str())
        src = meta or "ZK"

        try:
            cert = self._verify_drkey_reply(rep, meta)
            raw = decrypt_drkey(rep.p.cipher, self.private_key,
                                cert.subject_enc_key_raw)
        except SCIONVerificationError as e:
            logging.info("Invalid DRKeyReply from %s. Reason %s: %s", src, e,
                         rep.short_desc())
            return
        except CryptoError as e:
            logging.info("Unable to decrypt DRKeyReply from %s. Reason %s: %s",
                         src, e, rep.short_desc())
            return
        drkey = FirstOrderDRKey(rep.isd_as, self.addr.isd_as, rep.p.expTime,
                                raw)
        self.first_order_drkeys[drkey] = drkey
        if not from_zk:
            pld_packed = rep.copy().pack()
            try:
                self.drkey_cache.store("%s-%s" % (rep.isd_as, rep.p.expTime),
                                       pld_packed)
            except ZkNoConnection:
                # Sharing is best effort. The key is already stored locally,
                # so fall through and still notify the request handler, as
                # the cert chain and TRC reply paths do.
                logging.warning("Unable to store DRKey for %s in shared path: "
                                "no connection to ZK", rep.isd_as)
        self.drkey_protocol_requests.put((drkey, None))

    def _verify_drkey_reply(self, rep, meta):
        """
        Check a first order DRKey reply before using it.

        When metadata is given, the sender must match the ISD-AS named in the
        reply. The reply must not be older than DRKEY_REQUEST_TIMEOUT, and its
        signature must verify against the responder's certificate chain and
        TRC. If the chain or TRC is not in the trust store, it is requested
        and verification fails.

        :param DRKeyReply rep: the first order DRKey reply.
        :param UDPMetadata meta: the metadata.
        :returns: Certificate of the responder.
        :rtype: Certificate
        :raises: SCIONVerificationError
        """
        if meta and meta.ia != rep.isd_as:
            raise SCIONVerificationError("Response from other ISD-AS: %s" %
                                         rep.isd_as)
        if drkey_time() - rep.p.timestamp > DRKEY_REQUEST_TIMEOUT:
            raise SCIONVerificationError(
                "Expired reply from %s. %ss old. Max %ss" %
                (rep.isd_as, drkey_time() - rep.p.timestamp,
                 DRKEY_REQUEST_TIMEOUT))
        responder_trc = self.trust_store.get_trc(rep.isd_as[0])
        responder_chain = self.trust_store.get_cert(rep.isd_as,
                                                    rep.p.certVerSrc)
        # Collect every missing piece so that both get requested at once.
        missing = []
        if not responder_chain:
            self._send_cc_request(rep.isd_as, rep.p.certVerSrc)
            missing.append("Certificate not present for %s(v: %s)" %
                           (rep.isd_as, rep.p.certVerSrc))
        if not responder_trc:
            self._send_trc_request(rep.isd_as[0], rep.p.trcVer)
            missing.append("TRC not present for %s(v: %s)" %
                           (rep.isd_as[0], rep.p.trcVer))
        if missing:
            raise SCIONVerificationError(", ".join(missing))
        signed_input = get_signing_input_rep(rep.isd_as, rep.p.timestamp,
                                             rep.p.expTime, rep.p.cipher)
        try:
            verify_sig_chain_trc(signed_input, rep.p.signature, rep.isd_as,
                                 responder_chain, responder_trc)
        except SCIONVerificationError as e:
            raise SCIONVerificationError(str(e))
        return responder_chain.certs[0]

    def _check_drkey(self, drkey):
        """
        Tell whether a matching first order DRKey is stored locally.

        :param FirstOrderDRKey drkey: the searched DRKey.
        :returns: True if the key is in first_order_drkeys, False otherwise.
        :rtype: Bool
        """
        return drkey in self.first_order_drkeys

    def _fetch_drkey(self, drkey, _):
        """
        Fetch missing first order DRKey with the same (SrcIA, DstIA, expTime).

        Builds a signed DRKey request and sends it to the CS service address
        of the key's source AS. Nothing is sent if our own certificate chain
        or TRC is missing, or if no path is available.

        :param FirstOrderDRKey drkey: The missing DRKey.
        :param _: unused request info.
        """
        cert = self.trust_store.get_cert(self.addr.isd_as)
        trc = self.trust_store.get_trc(self.addr.isd_as[0])
        if not cert or not trc:
            logging.warning(
                "DRKeyRequest for %s not sent. Own CertChain/TRC not present.",
                drkey.src_ia)
            return
        req = get_drkey_request(drkey.src_ia, False, self.signing_key,
                                cert.certs[0].version, trc.version)
        path_meta = self._get_path_via_sciond(drkey.src_ia)
        if not path_meta:
            logging.warning("DRKeyRequest (for %s) not sent", req.short_desc())
            return
        meta = self._build_meta(drkey.src_ia,
                                host=SVCType.CS_A,
                                path=path_meta.fwd_path())
        req_id = mk_ctrl_req_id()
        # Attach the id to the payload, as the CC/TRC senders do; the original
        # generated and logged an id that was never sent.
        self.send_meta(CtrlPayload(DRKeyMgmt(req), req_id=req_id), meta)
        logging.info("DRKeyRequest (%s) sent to %s via %s [id: %016x]",
                     req.short_desc(), meta, path_meta, req_id)

    def _reply_proto_drkey(self, drkey, meta):
        """
        Reply callback of the DRKey request handler; currently a no-op stub.

        :param drkey: the DRKey the pending request was waiting for.
        :param meta: request information handed over by the request handler.
        """
        pass  # TODO(roosd): implement in future PR

    def _get_drkey_secret(self, exp_time):
        """
        Return the DRKey secret for the given expiration time, creating and
        storing a new one if none is found.

        :param int exp_time: expiration time of the drkey secret
        :return: the according drkey secret
        :rtype: DRKeySecretValue
        """
        secret = self.drkey_secrets.get(exp_time)
        if secret:
            return secret
        derived = kdf(self.config.master_as_key, b"Derive DRKey Key")
        secret = DRKeySecretValue(derived, exp_time)
        self.drkey_secrets[secret.exp_time] = secret
        return secret

    def _init_metrics(self):
        """
        Initialise the metrics of the base class, then touch every per-type
        request counter with an increment of 0 and set IS_MASTER to 0.
        """
        super()._init_metrics()
        base_labels = self._labels
        for request_type in ("trc", "cc", "drkey"):
            counter = REQS_TOTAL.labels(**base_labels, type=request_type)
            counter.inc(0)
        IS_MASTER.labels(**base_labels).set(0)

    def run(self):
        """
        Start the worker in a daemon thread (wrapped in thread_safety_net),
        then hand over to the base class run loop.
        """
        worker_thread = threading.Thread(
            target=thread_safety_net,
            args=(self.worker, ),
            name="CS.worker",
            daemon=True,
        )
        worker_thread.start()
        super().run()
示例#2
0
文件: base.py 项目: stschwar/scion
class PathServer(SCIONElement, metaclass=ABCMeta):
    """
    The SCION Path Server.
    """
    SERVICE_TYPE = PATH_SERVICE
    MAX_SEG_NO = 5  # TODO: replace by config variable.
    # ZK path for incoming PATHs
    ZK_PATH_CACHE_PATH = "path_cache"
    # ZK path for incoming REVs
    ZK_REV_CACHE_PATH = "rev_cache"
    # Max number of segments per propagation packet
    PROP_LIMIT = 5
    # Max number of segments per ZK cache entry
    ZK_SHARE_LIMIT = 10
    # Time to store revocations in zookeeper
    ZK_REV_OBJ_MAX_AGE = HASHTREE_EPOCH_TIME
    # TTL of segments in the queue for ZK (in seconds)
    SEGS_TO_ZK_TTL = 10 * 60

    def __init__(self,
                 server_id,
                 conf_dir,
                 spki_cache_dir=GEN_CACHE_PATH,
                 prom_export=None):
        """
        Set up the segment databases, revocation state, payload dispatch
        tables, and the ZooKeeper connection with its shared caches.

        :param str server_id: server identifier.
        :param str conf_dir: configuration directory.
        :param spki_cache_dir: forwarded unchanged to the base class
            constructor.
        :param str prom_export: prometheus export address.
        """
        super().__init__(server_id,
                         conf_dir,
                         spki_cache_dir=spki_cache_dir,
                         prom_export=prom_export)
        self.config = self._load_as_conf()
        # Per-segment-type metric labels; None when self._labels is falsy.
        down_labels = {
            **self._labels, "type": "down"
        } if self._labels else None
        core_labels = {
            **self._labels, "type": "core"
        } if self._labels else None
        self.down_segments = PathSegmentDB(max_res_no=self.MAX_SEG_NO,
                                           labels=down_labels)
        self.core_segments = PathSegmentDB(max_res_no=self.MAX_SEG_NO,
                                           labels=core_labels)
        # Dict of pending requests.
        self.pending_req = defaultdict(
            lambda: ExpiringDict(1000, PATH_REQ_TOUT))
        self.pen_req_lock = threading.Lock()
        # Set up by _init_request_logger() at the end of this constructor.
        self._request_logger = None
        # Used when l/cPS doesn't have up/dw-path.
        self.waiting_targets = defaultdict(list)
        self.revocations = RevCache(labels=self._labels)
        # A mapping from (hash tree root of AS, IFID) to segments
        self.htroot_if2seg = ExpiringDict(1000,
                                          self.config.revocation_tree_ttl)
        self.htroot_if2seglock = Lock()
        # Maps payload class and message type to the method that handles it.
        self.CTRL_PLD_CLASS_MAP = {
            PayloadClass.PATH: {
                PMT.IFSTATE_INFOS: self.handle_ifstate_infos,
                PMT.REQUEST: self.path_resolution,
                PMT.REPLY: self.handle_path_reply,
                PMT.REG: self.handle_seg_recs,
                PMT.REVOCATION: self._handle_revocation,
                PMT.SYNC: self.handle_seg_recs,
            },
            PayloadClass.CERT: {
                CertMgmtType.CERT_CHAIN_REQ: self.process_cert_chain_request,
                CertMgmtType.CERT_CHAIN_REPLY: self.process_cert_chain_reply,
                CertMgmtType.TRC_REPLY: self.process_trc_reply,
                CertMgmtType.TRC_REQ: self.process_trc_request,
            },
        }
        # Same idea for SCMP payloads.
        self.SCMP_PLD_CLASS_MAP = {
            SCMPClass.PATH: {
                SCMPPathClass.REVOKED_IF: self._handle_scmp_revocation,
            },
        }
        # Expiring queues of segments and revocations destined for ZK.
        self._segs_to_zk = ExpiringDict(1000, self.SEGS_TO_ZK_TTL)
        self._revs_to_zk = ExpiringDict(1000, HASHTREE_EPOCH_TIME)
        self._zkid = ZkID.from_values(self.addr.isd_as, self.id,
                                      [(self.addr.host, self._port)])
        # A copy of the ZkID is packed, leaving self._zkid itself unpacked.
        self.zk = Zookeeper(self.topology.isd_as, PATH_SERVICE,
                            self._zkid.copy().pack(), self.topology.zookeepers)
        self.zk.retry("Joining party", self.zk.party_setup)
        # Shared caches; each is given the handler that consumes its entries.
        self.path_cache = ZkSharedCache(self.zk, self.ZK_PATH_CACHE_PATH,
                                        self._handle_paths_from_zk)
        self.rev_cache = ZkSharedCache(self.zk, self.ZK_REV_CACHE_PATH,
                                       self._rev_entries_handler)
        self._init_request_logger()

    def worker(self):
        """
        Background loop: once per cycle, process the entries of the shared ZK
        caches and try to take the master lock (expiring old cache entries
        while it is held), then run the per-cycle housekeeping hooks.
        """
        cycle = 1.0
        cycle_start = SCIONTime.get_time()
        while self.run_flag.is_set():
            quiet = self._quiet_startup()
            sleep_interval(cycle_start, cycle, "cPS.worker cycle", quiet)
            cycle_start = SCIONTime.get_time()
            try:
                self.zk.wait_connected()
                self.path_cache.process()
                self.rev_cache.process()
                # Non-blocking attempt to become (or stay) master.
                lock_state = self.zk.get_lock(lock_timeout=0, conn_timeout=0)
                if lock_state:
                    if lock_state == ZK_LOCK_SUCCESS:
                        logging.info("Became master")
                    path_max_age = self.config.propagation_time * 10
                    self.path_cache.expire(path_max_age)
                    self.rev_cache.expire(self.ZK_REV_OBJ_MAX_AGE)
            except ZkNoConnection:
                logging.warning('worker(): ZkNoConnection')
            # These hooks run every cycle, with or without a ZK connection.
            self._update_master()
            self._propagate_and_sync()
            self._handle_pending_requests()
            self._update_metrics()

    def _update_master(self):
        """
        Hook called once per worker cycle; does nothing in this base class.
        """
        pass

    def _rev_entries_handler(self, raw_entries):
        """
        Handle revocations read from the shared ZK cache: each entry that
        validates is used to remove the segments it revokes; invalid entries
        are logged and skipped.

        :param raw_entries: iterable of raw RevocationInfo entries.
        """
        for entry in raw_entries:
            revocation = RevocationInfo.from_raw(entry)
            try:
                revocation.validate()
            except SCIONBaseError as e:
                logging.warning("Failed to validate RevInfo from zk: %s\n%s",
                                e, revocation.short_desc())
            else:
                self._remove_revoked_segments(revocation)

    def _add_rev_mappings(self, pcb):
        """
        Record, for every AS entry of the segment, that the segment uses the
        (hash tree root, egress IF) and (hash tree root, ingress IF) pairs
        taken from the entry's first PCB marking.

        :param pcb: the path segment to index.
        """
        seg_id = pcb.get_hops_hash()
        with self.htroot_if2seglock:
            for as_entry in pcb.iter_asms():
                hop_field = as_entry.pcbm(0).hof()
                tree_root = as_entry.p.hashTreeRoot
                for if_id in (hop_field.egress_if, hop_field.ingress_if):
                    segments = self.htroot_if2seg.setdefault(
                        (tree_root, if_id), set())
                    segments.add(seg_id)

    @abstractmethod
    def _handle_up_segment_record(self, pcb, **kwargs):
        """
        Handle an up-segment record; must be implemented by subclasses.

        :param pcb: the path segment.
        """
        raise NotImplementedError

    @abstractmethod
    def _handle_down_segment_record(self, pcb, **kwargs):
        """
        Handle a down-segment record; must be implemented by subclasses.

        :param pcb: the path segment.
        """
        raise NotImplementedError

    @abstractmethod
    def _handle_core_segment_record(self, pcb, **kwargs):
        """
        Handle a core-segment record; must be implemented by subclasses.

        :param pcb: the path segment.
        """
        raise NotImplementedError

    def _add_segment(self, pcb, seg_db, name, reverse=False):
        """
        Insert or update a segment in the given database and, if the database
        added or updated an entry, refresh the revocation mappings.

        :param pcb: the path segment.
        :param seg_db: database to update.
        :param name: segment type name used in the log messages.
        :param reverse: passed through to seg_db.update().
        :returns: True only if a new entry was added.
        """
        outcome = seg_db.update(pcb, reverse=reverse)
        if outcome == DBResult.ENTRY_ADDED:
            newly_added = True
        elif outcome == DBResult.ENTRY_UPDATED:
            newly_added = False
        else:
            return False
        self._add_rev_mappings(pcb)
        if newly_added:
            logging.info("%s-Segment registered: %s", name, pcb.short_id())
        else:
            logging.debug("%s-Segment updated: %s", name, pcb.short_id())
        return newly_added

    def handle_ifstate_infos(self, cpld, meta):
        """
        Handles IFStateInfos.

        For every info that is inactive and carries revocation info, the
        revocation is validated and passed on to _handle_revocation. Infos
        whose revocation fails validation are logged and skipped.

        :param cpld: control payload wrapping an IFStatePayload.
        :param meta: metadata of the sender.
        """
        pmgt = cpld.union
        infos = pmgt.union
        assert isinstance(infos, IFStatePayload), type(infos)
        for info in infos.iter_infos():
            if info.p.active or not info.p.revInfo:
                continue
            rev_info = info.rev_info()
            try:
                rev_info.validate()
            except SCIONBaseError as e:
                logging.warning(
                    "Failed to validate IFStateInfo RevInfo from %s: %s\n%s",
                    meta, e, rev_info.short_desc())
                continue
            # Forward the object that was just validated instead of calling
            # info.rev_info() a second time.
            self._handle_revocation(CtrlPayload(PathMgmt(rev_info)), meta)

    def _handle_scmp_revocation(self, pld, meta):
        """
        Handle a revocation carried in an SCMP payload: parse and validate
        it, then pass it on to _handle_revocation. Invalid revocations are
        logged and dropped.

        :param pld: SCMP payload whose info holds the raw revocation.
        :param meta: metadata of the sender.
        """
        revocation = RevocationInfo.from_raw(pld.info.rev_info)
        try:
            revocation.validate()
        except SCIONBaseError as e:
            logging.warning("Failed to validate SCMP RevInfo from %s: %s\n%s",
                            meta, e, revocation.short_desc())
        else:
            wrapped = CtrlPayload(PathMgmt(revocation))
            self._handle_revocation(wrapped, meta)

    def _handle_revocation(self, cpld, meta):
        """
        Handles a revocation of a segment, interface or hop.

        Invalid or already known revocations, and revocations received from a
        different ISD, are dropped. Otherwise the revocation is recorded,
        queued for ZK, applied to the stored segments, and forwarded.

        :param cpld: control payload wrapping a RevocationInfo.
        :param meta: metadata of the sender.
        """
        pmgt = cpld.union
        rev_info = pmgt.union
        assert isinstance(rev_info, RevocationInfo), type(rev_info)
        # Validate before checking for presence in self.revocations, as that
        # will trigger an assert failure if the rev_info is invalid.
        try:
            rev_info.validate()
        except SCIONBaseError as e:
            # Validation already done in the IFStateInfo and SCMP paths, so a
            # failure here means it's from a CtrlPld.
            logging.warning(
                "Failed to validate CtrlPld RevInfo from %s: %s\n%s", meta, e,
                rev_info.short_desc())
            return

        if rev_info in self.revocations:
            return
        logging.debug("Received revocation from %s: %s", meta,
                      rev_info.short_desc())
        # The original validated a second time here; rev_info has already
        # passed validate() above, so the duplicate was removed.
        if meta.ia[0] != self.addr.isd_as[0]:
            logging.info(
                "Dropping revocation received from a different ISD. Src: %s RevInfo: %s",
                meta, rev_info.short_desc())
            return
        self.revocations.add(rev_info)
        # Have to pack a copy.
        self._revs_to_zk[rev_info] = rev_info.copy().pack()
        # Remove segments that contain the revoked interface.
        self._remove_revoked_segments(rev_info)
        # Forward revocation to other path servers.
        self._forward_revocation(rev_info, meta)

    def _remove_revoked_segments(self, rev_info):
        """
        Delete every stored segment that matches the revocation.

        Both hashes returned by `ConnectedHashTree.get_possible_hashes` are
        tried together with the revoked interface ID as keys into
        `self.htroot_if2seg`; the segment IDs found are deleted from the
        down- and core-segment stores and, unless this is a core AS, from the
        up-segment store. Nothing is done when the revocation's epoch is not
        `EPOCH_OK`.

        :param rev_info: The revocation info
        :type rev_info: RevocationInfo
        """
        epoch_state = ConnectedHashTree.verify_epoch(rev_info.p.epoch)
        if epoch_state != ConnectedHashTree.EPOCH_OK:
            return
        prev_hash, next_hash = ConnectedHashTree.get_possible_hashes(rev_info)
        revoked_if = rev_info.p.ifID

        with self.htroot_if2seglock:
            removed = {"up": 0, "down": 0, "core": 0}
            for root in (prev_hash, next_hash):
                # pop() also drops the mapping entry for this (hash, ifID).
                for seg_id in self.htroot_if2seg.pop((root, revoked_if), []):
                    outcome = self.down_segments.delete(seg_id)
                    if outcome == DBResult.ENTRY_DELETED:
                        removed["down"] += 1
                    outcome = self.core_segments.delete(seg_id)
                    if outcome == DBResult.ENTRY_DELETED:
                        removed["core"] += 1
                    if self.topology.is_core_as:
                        continue
                    outcome = self.up_segments.delete(seg_id)
                    if outcome == DBResult.ENTRY_DELETED:
                        removed["up"] += 1
            logging.debug(
                "Removed segments revoked by [%s]: UP: %d DOWN: %d CORE: %d" %
                (rev_info.short_desc(), removed["up"], removed["down"],
                 removed["core"]))

    @abstractmethod
    def _forward_revocation(self, rev_info, meta):
        """
        Forwards a revocation to other path servers that need to be notified.

        :param rev_info: The RevInfo object.
        :param meta: The MessageMeta object.
        """
        raise NotImplementedError

    def _send_path_segments(self,
                            req,
                            req_id,
                            meta,
                            logger,
                            up=None,
                            core=None,
                            down=None):
        """
        Sends path-segments to requester (depending on Path Server's location).
        """
        up = up or set()
        core = core or set()
        down = down or set()
        all_segs = up | core | down
        if not all_segs:
            logger.warning("No segments to send for request: %s from: %s" %
                           (req.short_desc(), meta))
            return
        revs_to_add = self._peer_revs_for_segs(all_segs)
        recs = PathSegmentRecords.from_values(
            {
                PST.UP: up,
                PST.CORE: core,
                PST.DOWN: down
            }, revs_to_add)
        pld = PathSegmentReply.from_values(req.copy(), recs)
        self.send_meta(CtrlPayload(PathMgmt(pld), req_id=req_id), meta)
        logger.info("Sending PATH_REPLY with %d segment(s).", len(all_segs))

    def _peer_revs_for_segs(self, segs):
        """Returns a list of peer revocations for segments in 'segs'."""
        def _handle_one_seg(seg):
            for asm in seg.iter_asms():
                for pcbm in asm.iter_pcbms(1):
                    hof = pcbm.hof()
                    for if_id in [hof.ingress_if, hof.egress_if]:
                        rev_info = self.revocations.get((asm.isd_as(), if_id))
                        if rev_info:
                            revs_to_add.add(rev_info.copy())
                            return

        revs_to_add = set()
        for seg in segs:
            _handle_one_seg(seg)

        return list(revs_to_add)

    def _handle_pending_requests(self):
        """
        Retry all pending requests and drop the ones that could be served.

        Under `self.pen_req_lock`, every pending request is passed to
        `path_resolution` again (with `new_request=False`). When that returns
        a true value the request's meta is closed and the request is removed.
        Keys of `self.pending_req` whose request mapping is empty afterwards
        are removed as well.
        """
        rem_keys = []
        # Serve pending requests.
        with self.pen_req_lock:
            for key in self.pending_req:
                # Iterate over a snapshot: entries are deleted from the
                # mapping inside the loop, which would raise RuntimeError on
                # a live dict view.
                for req_key, (req, req_id, meta,
                              logger) in list(self.pending_req[key].items()):
                    if self.path_resolution(CtrlPayload(PathMgmt(req),
                                                        req_id=req_id),
                                            meta,
                                            new_request=False,
                                            logger=logger):
                        meta.close()
                        del self.pending_req[key][req_key]
                if not self.pending_req[key]:
                    rem_keys.append(key)
            # Outer keys are removed only after the outer iteration is done.
            for key in rem_keys:
                del self.pending_req[key]

    def _handle_paths_from_zk(self, raw_entries):
        """
        Handles cached paths through ZK, passed as a list.
        """
        for raw in raw_entries:
            recs = PathSegmentRecords.from_raw(raw)
            for type_, pcb in recs.iter_pcbs():
                seg_meta = PathSegMeta(pcb,
                                       self.continue_seg_processing,
                                       type_=type_,
                                       params={'from_zk': True})
                self._process_path_seg(seg_meta)
        if raw_entries:
            logging.debug("Processed %s segments from ZK", len(raw_entries))

    def handle_path_reply(self, cpld, meta):
        """
        Handle a path reply by processing the segment records it contains.

        :param cpld: control payload whose nested union is a PathSegmentReply.
        :param meta: metadata of the sender.
        """
        segment_reply = cpld.union.union
        assert isinstance(segment_reply,
                          PathSegmentReply), type(segment_reply)
        records = segment_reply.recs()
        self._handle_seg_recs(records, cpld.req_id, meta)

    def handle_seg_recs(self, cpld, meta):
        """
        Unwrap segment records from a control payload and process them.

        :param cpld: control payload; its nested union is passed on as the
            segment records, together with `cpld.req_id`.
        :param meta: metadata of the sender.
        """
        self._handle_seg_recs(cpld.union.union, cpld.req_id, meta)

    def _handle_seg_recs(self, seg_recs, req_id, meta):
        """
        Handles paths received from the network.

        The revocations contained in the records are added to
        `self.revocations`, then every segment is submitted to
        `_process_path_seg` together with the parameters returned by
        `_dispatch_params`.

        :param seg_recs: the PathSegmentRecords to process.
        :param req_id: request ID passed on to `_process_path_seg`.
        :param meta: metadata of the sender.
        """
        assert isinstance(seg_recs, PathSegmentRecords), type(seg_recs)
        extra_params = self._dispatch_params(seg_recs, meta)
        # Add revocations for peer interfaces included in the path segments.
        for revocation in seg_recs.iter_rev_infos():
            self.revocations.add(revocation)
        # Verify pcbs and process them
        for seg_type, segment in seg_recs.iter_pcbs():
            self._process_path_seg(
                PathSegMeta(segment, self.continue_seg_processing, meta,
                            seg_type, extra_params),
                req_id)

    def continue_seg_processing(self, seg_meta):
        """
        Continue processing a segment once it has been verified.

        Called for every verifiable path segment received from the network
        or ZK. The segment's extensions are handled, the segment is
        dispatched according to its type, and the pending requests are
        re-checked.

        :param seg_meta: carries the segment (`seg`), its type (`type`) and
            the keyword parameters (`params`) for the dispatch.
        """
        segment = seg_meta.seg
        logging.debug("Successfully verified PCB %s" % segment.short_id())
        seg_type, dispatch_kwargs = seg_meta.type, seg_meta.params
        self.handle_ext(segment)
        self._dispatch_segment_record(seg_type, segment, **dispatch_kwargs)
        self._handle_pending_requests()

    def handle_ext(self, pcb):
        """
        Handle beacon extensions.

        Every AS marking of the segment is asked for its routing policy
        extension; each one that is present is passed to
        `handle_routing_pol_ext`.

        :param pcb: the path segment whose AS markings are inspected.
        """
        for asm in pcb.iter_asms():
            policy = asm.routing_pol_ext()
            if not policy:
                continue
            self.handle_routing_pol_ext(policy)

    def handle_routing_pol_ext(self, ext):
        """
        Log a routing policy extension; no further handling is done yet.

        :param ext: the routing policy extension.
        """
        # TODO(Sezer): Implement extension handling
        message = "Routing policy extension: %s" % ext
        logging.debug(message)

    def _dispatch_segment_record(self, type_, seg, **kwargs):
        # Check that segment does not contain a revoked interface.
        if not self._validate_segment(seg):
            return
        handle_map = {
            PST.UP: self._handle_up_segment_record,
            PST.CORE: self._handle_core_segment_record,
            PST.DOWN: self._handle_down_segment_record,
        }
        handle_map[type_](seg, **kwargs)

    def _validate_segment(self, seg):
        """
        Check segment for revoked upstream/downstream interfaces.

        :param seg: The PathSegment object.
        :return: False, if the path segment contains a revoked upstream/
            downstream interface (not peer). True otherwise.
        """
        for asm in seg.iter_asms():
            pcbm = asm.pcbm(0)
            for if_id in [pcbm.hof().ingress_if, pcbm.hof().egress_if]:
                rev_info = self.revocations.get((asm.isd_as(), if_id))
                if rev_info:
                    logging.debug(
                        "Found revoked interface (%d, %s) in segment %s." %
                        (rev_info.p.ifID, rev_info.isd_as(), seg.short_desc()))
                    return False
        return True

    def _dispatch_params(self, pld, meta):
        return {}

    def _propagate_and_sync(self):
        self._share_via_zk()
        self._share_revs_via_zk()

    def _gen_prop_recs(self, container, limit=PROP_LIMIT):
        """
        Drain `container` and yield its segments in batches.

        Entries are popped from the front of `container` (which is emptied
        in the process) and copies of the segments are grouped by type. A
        batch is yielded whenever it holds `limit` segments, and a final,
        smaller batch is yielded for any remainder.

        :param container: ordered mapping whose values are (type, segment)
            pairs; must support `popitem(last=False)`.
        :param limit: maximum number of segments per yielded batch.
        :return: generator of dicts mapping segment type to list of copies.
        """
        batch = defaultdict(list)
        batch_size = 0
        while container:
            try:
                entry = container.popitem(last=False)
            except KeyError:
                # Nothing could be popped; re-check the loop condition.
                continue
            _, (seg_type, segment) = entry
            batch_size += 1
            batch[seg_type].append(segment.copy())
            if batch_size < limit:
                continue
            yield (batch)
            batch = defaultdict(list)
            batch_size = 0
        if batch:
            yield (batch)

    @abstractmethod
    def path_resolution(self,
                        path_request,
                        meta,
                        new_request=True,
                        logger=None):
        """
        Handle a path request of any type.

        Abstract: this base implementation only raises NotImplementedError.

        :param path_request: the path request to resolve.
        :param meta: metadata of the requester.
        :param new_request: False when a pending request is being retried.
        :param logger: logger to use for this request.
        """
        raise NotImplementedError

    def _handle_waiting_targets(self, pcb):
        """
        Handle any queries that are waiting for a path to any core AS in an ISD.
        """
        dst_ia = pcb.first_ia()
        if not self.is_core_as(dst_ia):
            logging.warning("Invalid waiting target, not a core AS: %s",
                            dst_ia)
            return
        self._send_waiting_queries(dst_ia[0], pcb)

    def _send_waiting_queries(self, dst_isd, pcb):
        targets = self.waiting_targets[dst_isd]
        if not targets:
            return
        path = pcb.get_path(reverse_direction=True)
        src_ia = pcb.first_ia()
        while targets:
            (seg_req, logger) = targets.pop(0)
            meta = self._build_meta(ia=src_ia,
                                    path=path,
                                    host=SVCType.PS_A,
                                    reuse=True)
            req_id = mk_ctrl_req_id()
            self.send_meta(CtrlPayload(PathMgmt(seg_req), req_id=req_id), meta)
            logger.info("Waiting request (%s) sent to %s via %s [id: %016x]",
                        seg_req.short_desc(), meta, pcb.short_desc(), req_id)

    def _share_via_zk(self):
        if not self._segs_to_zk:
            return
        logging.info("Sharing %d segment(s) via ZK", len(self._segs_to_zk))
        for pcb_dict in self._gen_prop_recs(self._segs_to_zk,
                                            limit=self.ZK_SHARE_LIMIT):
            seg_recs = PathSegmentRecords.from_values(pcb_dict)
            self._zk_write(seg_recs.pack())

    def _share_revs_via_zk(self):
        if not self._revs_to_zk:
            return
        logging.info("Sharing %d revocation(s) via ZK", len(self._revs_to_zk))
        while self._revs_to_zk:
            try:
                data = self._revs_to_zk.popitem(last=False)[1]
            except KeyError:
                continue
            self._zk_write_rev(data)

    def _zk_write(self, data):
        """
        Store packed segment data in the shared path cache.

        The entry name is the hex hash of the data followed by the current
        SCION time. A missing ZK connection is logged as a warning and
        otherwise ignored.

        :param data: the packed data to store.
        """
        digest = crypto_hash(data).hex()
        try:
            entry_name = "%s-%s" % (digest, SCIONTime.get_time())
            self.path_cache.store(entry_name, data)
        except ZkNoConnection:
            logging.warning("Unable to store segment(s) in shared path: "
                            "no connection to ZK")

    def _zk_write_rev(self, data):
        """
        Store packed revocation data in the shared revocation cache.

        The entry name is the hex hash of the data followed by the current
        SCION time. A missing ZK connection is logged as a warning and
        otherwise ignored.

        :param data: the packed data to store.
        """
        digest = crypto_hash(data).hex()
        try:
            entry_name = "%s-%s" % (digest, SCIONTime.get_time())
            self.rev_cache.store(entry_name, data)
        except ZkNoConnection:
            logging.warning("Unable to store revocation(s) in shared path: "
                            "no connection to ZK")

    def _init_request_logger(self):
        """
        Initializes the request logger.

        Stores the "RequestLogger" logger on the instance and registers a
        formatter for it whose format string includes the `id` and `from`
        fields of a request.
        """
        self._request_logger = logging.getLogger("RequestLogger")
        # Create new formatter to include the request in the log.
        log_format = ("%(asctime)s [%(levelname)s] (%(threadName)s) %(message)s "
                      "{id=%(id)s, from=%(from)s}")
        add_formatter('RequestLogger', Rfc3339Formatter(log_format))

    def get_request_logger(self, req_id, meta):
        """
        Returns a logger adapter for a request.

        :param req_id: request ID, exposed to log records as `id`.
        :param meta: request metadata; its string form is exposed as `from`.
        :return: a `logging.LoggerAdapter` around the request logger.
        """
        # Create a logger for the request to log with context.
        context = {"id": req_id, "from": str(meta)}
        return logging.LoggerAdapter(self._request_logger, context)

    def _init_metrics(self):
        """
        Initialize the metrics of this server for its label set.

        After the superclass' initialization, the request counter is
        incremented by 0 and every gauge is set to 0.
        """
        super()._init_metrics()
        REQS_TOTAL.labels(**self._labels).inc(0)
        gauges = (REQS_PENDING, SEGS_TO_ZK, REVS_TO_ZK, HT_ROOT_MAPPTINGS,
                  IS_MASTER)
        for gauge in gauges:
            gauge.labels(**self._labels).set(0)

    def _update_metrics(self):
        """
        Updates all Gauge metrics. Subclass can update their own metrics but must
        call the superclass' implementation.
        """
        if not self._labels:
            return
        # Update pending requests metric.
        # XXX(shitz): This could become a performance problem should there ever be
        # a large amount of pending requests (>100'000).
        total_pending = 0
        with self.pen_req_lock:
            for reqs in self.pending_req.values():
                total_pending += len(reqs)
        REQS_PENDING.labels(**self._labels).set(total_pending)
        # Update SEGS_TO_ZK and REVS_TO_ZK metrics.
        SEGS_TO_ZK.labels(**self._labels).set(len(self._segs_to_zk))
        REVS_TO_ZK.labels(**self._labels).set(len(self._revs_to_zk))
        # Update HT_ROOT_MAPPTINGS metric.
        HT_ROOT_MAPPTINGS.labels(**self._labels).set(len(self.htroot_if2seg))
        # Update IS_MASTER metric.
        IS_MASTER.labels(**self._labels).set(int(self.zk.have_lock()))

    def run(self):
        """
        Run an instance of the Path Server.

        Starts the worker and the TRC/certificate request checker as daemon
        threads (each wrapped in `thread_safety_net`), then runs the
        superclass' `run`.
        """
        background = (
            (self.worker, "PS.worker"),
            (self._check_trc_cert_reqs, "Elem.check_trc_cert_reqs"),
        )
        for task, thread_name in background:
            threading.Thread(target=thread_safety_net,
                             args=(task, ),
                             name=thread_name,
                             daemon=True).start()
        super().run()
示例#3
0
文件: base.py 项目: xabarass/scion
class BeaconServer(SCIONElement, metaclass=ABCMeta):
    """
    The SCION PathConstructionBeacon Server.
    """
    SERVICE_TYPE = ServiceType.BS
    # ZK path for incoming PCBs
    ZK_PCB_CACHE_PATH = "pcb_cache"
    # ZK path for revocations.
    ZK_REVOCATIONS_PATH = "rev_cache"
    # Time revocation objects are cached in memory (in seconds).
    ZK_REV_OBJ_MAX_AGE = MIN_REVOCATION_TTL
    # Revocation TTL
    REVOCATION_TTL = MIN_REVOCATION_TTL
    # Revocation Overlapping (seconds)
    REVOCATION_OVERLAP = 2
    # Interval to check for timed out interfaces.
    IF_TIMEOUT_INTERVAL = 1
    # Interval to send keep-alive msgs
    IFID_INTERVAL = 1
    # Interval between two consecutive requests (in seconds).
    CERT_REQ_RATE = 10

    def __init__(self,
                 server_id,
                 conf_dir,
                 spki_cache_dir=GEN_CACHE_PATH,
                 prom_export=None,
                 sciond_path=None):
        """
        Set up the beacon server: configuration, keys, path policy,
        per-interface state, payload dispatch tables and the ZK connection
        with its shared caches.

        :param str server_id: server identifier.
        :param str conf_dir: configuration directory.
        :param str spki_cache_dir: passed on to the superclass.
        :param str prom_export: prometheus export address.
        :param str sciond_path: path to sciond socket.
        """
        super().__init__(server_id,
                         conf_dir,
                         spki_cache_dir=spki_cache_dir,
                         prom_export=prom_export,
                         sciond_path=sciond_path)
        self.config = self._load_as_conf()
        self.master_key_0 = get_master_key(self.conf_dir, MASTER_KEY_0)
        self.master_key_1 = get_master_key(self.conf_dir, MASTER_KEY_1)
        # TODO: add 2 policies
        policy_file = os.path.join(conf_dir, PATH_POLICY_FILE)
        self.path_policy = PathPolicy.from_file(policy_file)
        self.signing_key = get_sig_key(self.conf_dir)
        self.of_gen_key = kdf(self.master_key_0, b"Derive OF Key")
        # Amount of time units a HOF is valid (time unit is EXP_TIME_UNIT).
        hof_units = self.config.segment_ttl / EXP_TIME_UNIT
        self.default_hof_exp_time = int(hof_units)
        # One state object per interface ID known from ifid2br.
        self.ifid_state = {ifid: InterfaceState() for ifid in self.ifid2br}
        self.ifid_state_lock = RLock()
        self.if_revocations = {}

        # Dispatch tables: payload class -> payload type -> handler.
        pcb_handlers = {PayloadClass.PCB: self.handle_pcb}
        ifid_handlers = {PayloadClass.IFID: self.handle_ifid_packet}
        cert_handlers = {
            CertMgmtType.CERT_CHAIN_REQ: self.process_cert_chain_request,
            CertMgmtType.CERT_CHAIN_REPLY: self.process_cert_chain_reply,
            CertMgmtType.TRC_REPLY: self.process_trc_reply,
            CertMgmtType.TRC_REQ: self.process_trc_request,
        }
        path_handlers = {
            PMT.IFSTATE_REQ: self._handle_ifstate_request,
            PMT.REVOCATION: self._handle_revocation,
        }
        self.CTRL_PLD_CLASS_MAP = {
            PayloadClass.PCB: pcb_handlers,
            PayloadClass.IFID: ifid_handlers,
            PayloadClass.CERT: cert_handlers,
            PayloadClass.PATH: path_handlers,
        }
        scmp_path_handlers = {
            SCMPPathClass.REVOKED_IF: self._handle_scmp_revocation,
        }
        self.SCMP_PLD_CLASS_MAP = {SCMPClass.PATH: scmp_path_handlers}

        # Join the ZK party before creating the caches that use the
        # connection.
        local_ia = self.addr.isd_as
        public_addrs = [(self.addr.host, self._port)]
        zkid = ZkID.from_values(local_ia, self.id, public_addrs).pack()
        self.zk = Zookeeper(local_ia, self.SERVICE_TYPE, zkid,
                            self.topology.zookeepers)
        self.zk.retry("Joining party", self.zk.party_setup)
        self.pcb_cache = ZkSharedCache(self.zk, self.ZK_PCB_CACHE_PATH,
                                       self._handle_pcbs_from_zk)
        self.revobjs_cache = ZkSharedCache(self.zk, self.ZK_REVOCATIONS_PATH,
                                           self.process_rev_objects)
        self.local_rev_cache = RevCache()
        self._rev_seg_lock = RLock()

    def propagate_downstream_pcb(self, pcb):
        """
        Propagates the beacon to all children.

        Child interfaces without a known remote interface ID are skipped, as
        are those for which no extended beacon could be built. When labels
        are configured, the "down" propagation counter is increased by the
        number of beacons sent.

        :param pcb: path segment.
        :type pcb: PathSegment
        :return: mapping of (child ISD-AS, interface ID) to the list of
            short IDs of the beacons propagated over that interface.
        :rtype: defaultdict
        """
        sent = defaultdict(list)
        sent_count = 0
        for child in self.topology.child_interfaces:
            if not child.to_if_id:
                continue
            extended, meta = self._mk_prop_pcb_meta(pcb.copy(), child.isd_as,
                                                    child.if_id)
            if not extended:
                continue
            self.send_meta(CtrlPayload(extended.pcb()), meta)
            sent[(child.isd_as, child.if_id)].append(pcb.short_id())
            sent_count += 1
        if self._labels:
            counter = BEACONS_PROPAGATED.labels(**self._labels, type="down")
            counter.inc(sent_count)
        return sent

    def _mk_prop_pcb_meta(self, pcb, dst_ia, egress_if):
        ts = pcb.get_timestamp()
        asm = self._create_asm(pcb.ifID, egress_if, ts, pcb.last_hof())
        if not asm:
            return None, None
        pcb.add_asm(asm, ProtoSignType.ED25519, self.addr.isd_as.pack())
        pcb.sign(self.signing_key)
        one_hop_path = self._create_one_hop_path(egress_if)
        return pcb, self._build_meta(ia=dst_ia,
                                     host=SVCType.BS_A,
                                     path=one_hop_path,
                                     one_hop=True)

    def _create_one_hop_path(self, egress_if):
        """
        Build a one-hop path leaving through `egress_if`.

        The path consists of an info field stamped with the current SCION
        time and two hop fields; only the first hop field is filled in and
        given a MAC.

        :param egress_if: ID of the egress interface of the first hop.
        :return: the SCIONPath.
        """
        now = int(SCIONTime.get_time())
        info_field = InfoOpaqueField.from_values(now, self.addr.isd_as[0],
                                                 hops=2)
        first_hop = HopOpaqueField.from_values(OneHopPathExt.HOF_EXP_TIME, 0,
                                               egress_if)
        first_hop.set_mac(self.of_gen_key, now, None)
        # Return a path where second HF is empty.
        hop_fields = [first_hop, HopOpaqueField()]
        return SCIONPath.from_values(info_field, hop_fields)

    def hof_exp_time(self, ts):
        """
        Return the ExpTime based on IF timestamp and the certificate chain/TRC.
        The certificate chain must be valid for the entire HOF lifetime.

        The result is the smaller of the default HOF expiration time and the
        number of EXP_TIME_UNITs left until the AS certificate expires; it
        can be zero or negative when the certificate expires at or before
        `ts`.

        :param int ts: IF timestamp
        :return: HF ExpTime
        :rtype: int
        """
        remaining = self._get_my_cert().as_cert.expiration_time - ts
        cert_bound = int(remaining / EXP_TIME_UNIT)
        return min(cert_bound, self.default_hof_exp_time)

    def _mk_if_info(self, if_id):
        """
        Small helper method to make it easier to deal with ingress/egress
        interface being 0 while building ASMarkings.

        :param if_id: interface ID; a false value (e.g. 0) means "no
            interface".
        :return: dict with the keys "remote_ia", "remote_if" and "mtu". For
            a false `if_id` these are a zero ISD-AS, 0 and 0; otherwise they
            are taken from the interface with that ID.
        """
        info = dict(remote_ia=ISD_AS.from_values(0, 0), remote_if=0, mtu=0)
        if if_id:
            intf = self.ifid2br[if_id].interfaces[if_id]
            info.update(remote_ia=intf.isd_as,
                        remote_if=intf.to_if_id,
                        mtu=intf.mtu)
        return info

    @abstractmethod
    def handle_pcbs_propagation(self):
        """
        Main loop to propagate received beacons.

        Abstract: this base implementation only raises NotImplementedError.
        """
        raise NotImplementedError

    def _log_propagations(self, propagated_pcbs):
        for (isd_as, if_id), pcbs in propagated_pcbs.items():
            logging.debug("Propagated %d PCBs to %s via %s (%s)", len(pcbs),
                          isd_as, if_id, ", ".join(pcbs))

    def _handle_pcbs_from_zk(self, pcbs):
        """
        Handles cached pcbs through ZK, passed as a list.
        """
        for pcb in pcbs:
            try:
                pcb = PCB.from_raw(pcb)
            except SCIONParseError as e:
                logging.error("Unable to parse raw pcb: %s", e)
                continue
            self.handle_pcb(CtrlPayload(pcb))
        if pcbs:
            logging.debug("Processed %s PCBs from ZK", len(pcbs))

    def handle_pcb(self, cpld, meta=None):
        """
        Handles pcbs received from the network.

        The beacon is converted to a path segment. When metadata is given,
        the segment's `ifID` is set to the ingress interface of the current
        hop field of the packet's path. Segments rejected by the path policy
        filters or by `_filter_pcb` are logged and dropped; the others are
        submitted for verification.

        :param cpld: control payload whose union is a PCB.
        :param meta: metadata of the sender, or None.
        """
        beacon = cpld.union
        assert isinstance(beacon, PCB), type(beacon)
        segment = beacon.pseg()
        if meta:
            segment.ifID = meta.path.get_hof().ingress_if
        try:
            self.path_policy.check_filters(segment)
        except SCIONPathPolicyViolated as err:
            logging.debug("Segment dropped due to path policy: %s\n%s" %
                          (err, segment.short_desc()))
            return
        if self._filter_pcb(segment):
            seg_meta = PathSegMeta(segment, self.continue_seg_processing,
                                   meta)
            self._process_path_seg(seg_meta, cpld.req_id)
            return
        logging.debug("Segment dropped due to looping: %s" %
                      segment.short_desc())

    def continue_seg_processing(self, seg_meta):
        """
        Continue processing a pcb once it has been verified.

        Called for every verified pcb received from the network or ZK. A
        pcb that arrived with metadata is first stored in the shared PCB
        cache (a missing ZK connection is only logged). Then its extensions
        are handled and it is passed to `_handle_verified_beacon`.

        :param seg_meta: carries the segment (`seg`) and the sender metadata
            (`meta`, false for segments that came from ZK).
        """
        segment = seg_meta.seg
        logging.debug("Successfully verified PCB %s", segment.short_id())
        if seg_meta.meta:
            # Segment was received from network, not from zk. Share segment
            # with other beacon servers in this AS.
            cache_key = "%s-%s" % (segment.get_hops_hash(hex=True),
                                   time.time())
            try:
                packed = segment.pcb().copy().pack()
                self.pcb_cache.store(cache_key, packed)
            except ZkNoConnection:
                logging.error("Unable to store PCB in shared cache: "
                              "no connection to ZK")
        self.handle_ext(segment)
        self._handle_verified_beacon(segment)

    def _filter_pcb(self, pcb, dst_ia=None):
        return True

    def handle_ext(self, pcb):
        """
        Handle beacon extensions.

        Each AS marking's routing policy extension, when present, is passed
        to `handle_routing_pol_ext`.

        :param pcb: the path segment whose AS markings are inspected.
        """
        # Handle PCB extensions
        extensions = (asm.routing_pol_ext() for asm in pcb.iter_asms())
        for policy in extensions:
            if policy:
                self.handle_routing_pol_ext(policy)

    def handle_routing_pol_ext(self, ext):
        """
        Log a routing policy extension; no further handling is done yet.

        :param ext: the routing policy extension.
        """
        # TODO(Sezer): Implement routing policy extension handling
        text = "Routing policy extension: %s" % ext
        logging.debug(text)

    @abstractmethod
    def register_segments(self):
        """
        Registers paths according to the received beacons.

        Abstract: this base implementation only raises NotImplementedError.
        """
        raise NotImplementedError

    def _log_registrations(self, registrations, seg_type):
        reg_cnt = 0
        for (dst_meta, dst_type), pcbs in registrations.items():
            reg_cnt += len(pcbs)
            logging.debug("Registered %d %s-segments @ %s:%s (%s)", len(pcbs),
                          seg_type, dst_type.upper(), dst_meta,
                          ", ".join(pcbs))
        if self._labels:
            SEGMENTS_REGISTERED.labels(**self._labels,
                                       type=seg_type).inc(reg_cnt)

    def _create_asm(self, in_if, out_if, ts, prev_hof):
        pcbms = list(self._create_pcbms(in_if, out_if, ts, prev_hof))
        if not pcbms:
            return None
        chain = self._get_my_cert()
        _, cert_ver = chain.get_leaf_isd_as_ver()
        return ASMarking.from_values(self.addr.isd_as,
                                     self._get_my_trc().version, cert_ver,
                                     pcbms, self.topology.mtu)

    def _create_pcbms(self, in_if, out_if, ts, prev_hof):
        up_pcbm = self._create_pcbm(in_if, out_if, ts, prev_hof)
        if not up_pcbm:
            return
        yield up_pcbm
        for intf in sorted(self.topology.peer_interfaces):
            in_if = intf.if_id
            with self.ifid_state_lock:
                if (not self.ifid_state[in_if].is_active()
                        and not self._quiet_startup()):
                    continue
            peer_pcbm = self._create_pcbm(in_if,
                                          out_if,
                                          ts,
                                          up_pcbm.hof(),
                                          xover=True)
            if peer_pcbm:
                yield peer_pcbm

    def _create_pcbm(self, in_if, out_if, ts, prev_hof, xover=False):
        in_info = self._mk_if_info(in_if)
        if in_info["remote_ia"].int() and not in_info["remote_if"]:
            return None
        out_info = self._mk_if_info(out_if)
        if out_info["remote_ia"].int() and not out_info["remote_if"]:
            return None
        exp_time = self.hof_exp_time(ts)
        if exp_time <= 0:
            logging.error("Invalid hop field expiration time value: %s",
                          exp_time)
            return None
        hof = HopOpaqueField.from_values(exp_time, in_if, out_if, xover=xover)
        hof.set_mac(self.of_gen_key, ts, prev_hof)
        return PCBMarking.from_values(in_info["remote_ia"],
                                      in_info["remote_if"], in_info["mtu"],
                                      out_info["remote_ia"],
                                      out_info["remote_if"], hof)

    def _terminate_pcb(self, pcb):
        """
        Copies a PCB, terminates it and adds the segment ID.

        Terminating a PCB means adding a opaque field with the egress IF set
        to 0, i.e., there is no AS to forward a packet containing this path
        segment to.
        """
        pcb = pcb.copy()
        asm = self._create_asm(pcb.ifID, 0, pcb.get_timestamp(),
                               pcb.last_hof())
        if not asm:
            return None
        pcb.add_asm(asm, ProtoSignType.ED25519, self.addr.isd_as.pack())
        return pcb

    def handle_ifid_packet(self, cpld, meta):
        """
        Update the interface state for the corresponding interface.

        The interface is the ingress interface of the current hop field of
        the packet's path. Its remote interface ID is set from the payload
        and its state is updated. When the interface was not active before
        and this server holds the ZK lock, all border routers are informed
        that the interface is up.

        :param cpld: control payload whose union is an IFIDPayload.
        :param meta: metadata of the sender; `meta.pkt` is the packet.
        :raises SCIONKeyError: if the interface ID is not a known one.
        """
        payload = cpld.union
        assert isinstance(payload, IFIDPayload), type(payload)
        ifid = meta.pkt.path.get_hof().ingress_if
        with self.ifid_state_lock:
            if ifid not in self.ifid_state:
                raise SCIONKeyError("Invalid IF %d in IFIDPayload" % ifid)
            router = self.ifid2br[ifid]
            router.interfaces[ifid].to_if_id = payload.p.origIF
            previous = self.ifid_state[ifid].update()
            if previous == InterfaceState.INACTIVE:
                logging.info("IF %d activated.", ifid)
            elif previous in [
                    InterfaceState.TIMED_OUT, InterfaceState.REVOKED
            ]:
                logging.info("IF %d came back up.", ifid)
            if previous != InterfaceState.ACTIVE and self.zk.have_lock():
                # Inform BRs about the interface coming up.
                metas = []
                for border_router in self.topology.border_routers:
                    host, port = border_router.ctrl_addrs.public
                    metas.append(
                        UDPMetadata.from_values(host=host, port=port))
                state_info = IFStateInfo.from_values(ifid, True)
                self._send_ifstate_update([state_info], metas)

    def run(self):
        """
        Run an instance of the Beacon Server.

        Starts the background tasks as daemon threads (each wrapped in
        `thread_safety_net`), then runs the superclass' `run`.
        """
        # For the IFID update thread see
        # https://github.com/scionproto/scion/issues/308
        background = (
            (self.worker, "BS.worker"),
            (self._send_ifid_updates, "BS._send_if_updates"),
            (self._handle_if_timeouts, "BS._handle_if_timeouts"),
            (self._check_trc_cert_reqs, "Elem.check_trc_cert_reqs"),
            (self._check_local_cert, "BS._check_local_cert"),
        )
        for task, thread_name in background:
            threading.Thread(target=thread_safety_net,
                             args=(task, ),
                             name=thread_name,
                             daemon=True).start()
        super().run()

    def worker(self):
        """
        Worker thread that takes care of reading shared PCBs from ZK, and
        propagating PCBS/registering paths when master.

        Runs until `self.run_flag` is cleared. Each cycle processes the
        shared PCB and revocation caches and tries to take the ZK lock; the
        rest of the cycle (cache expiry, propagation, registration) only runs
        when the lock is held. A cycle is abandoned when there is no ZK
        connection.
        """
        # Timestamps of the last propagation/registration; 0 makes both due
        # in the first cycle that reaches them.
        last_propagation = last_registration = 0
        worker_cycle = 1.0
        start = time.time()
        while self.run_flag.is_set():
            sleep_interval(start, worker_cycle, "BS.worker cycle",
                           self._quiet_startup())
            start = time.time()
            # Update IS_MASTER metric.
            if self._labels:
                IS_MASTER.labels(**self._labels).set(int(self.zk.have_lock()))
            try:
                self.zk.wait_connected()
                self.pcb_cache.process()
                self.revobjs_cache.process()
                self.handle_rev_objs()

                ret = self.zk.get_lock(lock_timeout=0, conn_timeout=0)
                if not ret:  # Failed to get the lock
                    continue
                elif ret == ZK_LOCK_SUCCESS:
                    # The lock was newly acquired in this cycle.
                    logging.info("Became master")
                    self._became_master()
                self.pcb_cache.expire(self.config.propagation_time * 10)
                self.revobjs_cache.expire(self.ZK_REV_OBJ_MAX_AGE)
            except ZkNoConnection:
                # Without a ZK connection, skip the rest of this cycle.
                continue
            now = time.time()
            if now - last_propagation >= self.config.propagation_time:
                self.handle_pcbs_propagation()
                last_propagation = now
            if (self.config.registers_paths and
                    now - last_registration >= self.config.registration_time):
                try:
                    self.register_segments()
                except SCIONKeyError as e:
                    logging.error("Error while registering segments: %s", e)
                    pass
                # The timestamp is updated even when registration failed.
                last_registration = now

    def _became_master(self):
        """
        Called when a BS becomes the new master. Resets some state that will be
        rebuilt over time.
        """
        # Reset all timed-out and revoked interfaces to inactive.
        with self.ifid_state_lock:
            for (_, ifstate) in self.ifid_state.items():
                if not ifstate.is_active():
                    ifstate.reset()

    def _get_my_trc(self):
        return self.trust_store.get_trc(self.addr.isd_as[0])

    def _get_my_cert(self):
        return self.trust_store.get_cert(self.addr.isd_as)

    @abstractmethod
    def _handle_verified_beacon(self, pcb):
        """
        Once a beacon has been verified, place it into the right containers.

        Abstract hook: concrete subclasses must override it. This base
        version only raises.

        :param pcb: verified path segment.
        :type pcb: PathSegment
        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def process_rev_objects(self, rev_infos):
        """
        Processes revocation infos stored in Zookeeper.

        Each raw entry is parsed into a SignedRevInfo and handed to
        ``check_revocation``. Entries that fail to parse are logged and
        skipped. When the check calls back with a falsy argument, the
        revocation is added to the local revocation cache.

        :param rev_infos: raw (packed) signed revocation infos.
        :type rev_infos: list
        """

        def _cache_when_valid(srev_info):
            # Build a callback bound to *this* entry. Binding through a
            # factory (rather than closing over the loop variable) keeps the
            # right object even if the callback runs after the loop moved on.
            def _callback(err):
                # Same convention as the callback in _handle_revocation: a
                # truthy argument means the revocation must not be processed
                # (presumably a failed check -- confirm against
                # check_revocation).
                if err:
                    return False
                return self.local_rev_cache.add(srev_info)
            return _callback

        with self._rev_seg_lock:
            for raw in rev_infos:
                try:
                    srev_info = SignedRevInfo.from_raw(raw)
                except SCIONParseError as e:
                    logging.error("Error parsing revocation info from ZK: %s",
                                  e)
                    continue
                # The callback performs the cache insertion itself. It used to
                # return a nested lambda that nobody called, so revocations
                # read from ZK never reached the local cache.
                self.check_revocation(srev_info, _cache_when_valid(srev_info))

    def _issue_revocations(self, revoked_ifs):
        """
        Create, sign and distribute a revocation for each given interface.

        Does nothing unless this BS holds the ZK lock (is master). Each
        signed revocation is stored in ``self.if_revocations`` and passed to
        ``_process_revocation``. Afterwards a single IFState update carrying
        all revocations is sent to every border router and, when a lookup
        succeeds, to the local path service.

        :param list revoked_ifs: A list of interfaces that needs to be revoked.
        """
        if not self.zk.have_lock():
            # Only the master BS issues revocations.
            return
        state_infos = []
        for ifid in revoked_ifs:
            router = self.ifid2br[ifid]
            rev_info = RevocationInfo.from_values(
                self.addr.isd_as, ifid, router.interfaces[ifid].link_type,
                int(time.time()), self.REVOCATION_TTL)
            logging.info("Issuing revocation: %s", rev_info.short_desc())
            if self._labels:
                REVOCATIONS_ISSUED.labels(**self._labels).inc()
            _, cert_ver = self._get_my_cert().get_leaf_isd_as_ver()
            sign_src = DefaultSignSrc.from_values(
                rev_info.isd_as(), cert_ver, self._get_my_trc().version)
            signed_rev = SignedRevInfo.from_values(
                rev_info.copy().pack(), ProtoSignType.ED25519,
                sign_src.pack())
            signed_rev.sign(self.signing_key)
            # Remember the revocation, then apply it locally.
            self.if_revocations[ifid] = signed_rev
            self._process_revocation(signed_rev)
            state_infos.append(
                IFStateInfo.from_values(ifid, False, signed_rev))
        # Recipients: all border routers ...
        recipients = []
        for router in self.topology.border_routers:
            host, port = router.ctrl_addrs.public
            recipients.append(UDPMetadata.from_values(host=host, port=port))
        # ... plus the local path service, if one is configured and found.
        if self.topology.path_servers:
            try:
                ps_host, ps_port = self.dns_query_topo(ServiceType.PS)[0]
            except SCIONServiceLookupError:
                ps_host = ps_port = None
            if ps_host:
                recipients.append(
                    UDPMetadata.from_values(host=ps_host, port=ps_port))
        self._send_ifstate_update(state_infos, recipients)

    def _handle_scmp_revocation(self, pld, meta):
        """
        Handle a revocation carried in an SCMP payload.

        The signed revocation is parsed from ``pld.info.srev_info``, wrapped
        in a path-management control payload and forwarded to
        ``_handle_revocation`` together with ``meta``.
        """
        raw_srev = pld.info.srev_info
        wrapped = CtrlPayload(PathMgmt(SignedRevInfo.from_raw(raw_srev)))
        self._handle_revocation(wrapped, meta)

    def _handle_revocation(self, cpld, meta):
        """
        Handle a signed revocation received in a control payload.

        The revocation is passed to ``check_revocation``; when that calls
        back with a falsy argument, it is processed via
        ``_process_revocation``.

        :param cpld: control payload whose inner union is the signed
            revocation.
        :param meta: metadata of the sender.
        """
        srev_info = cpld.union.union
        rev_info = srev_info.rev_info()
        assert isinstance(rev_info, RevocationInfo), type(rev_info)
        logging.debug("Received revocation from %s: %s", meta,
                      rev_info.short_desc())

        def _on_checked(x):
            # A truthy argument means the revocation must not be processed.
            if x:
                return False
            return self._process_revocation(srev_info)

        self.check_revocation(srev_info, _on_checked, meta)

    def handle_rev_objs(self):
        """
        Remove the PCBs affected by each revocation held in the local
        revocation cache. Runs under the revocation/segment lock.
        """
        with self._rev_seg_lock:
            cached_revs = self.local_rev_cache.values()
            for signed_rev in cached_revs:
                self._remove_revoked_pcbs(signed_rev.rev_info())

    def _process_revocation(self, srev_info):
        """
        Apply a signed revocation locally and share it through ZK.

        A copy of the revocation is added to the local revocation cache, the
        packed revocation is stored in the shared revocation cache (a missing
        ZK connection is logged, not raised), and PCBs containing the revoked
        interface are removed. A revocation for interface ID 0 is rejected
        with an error log.

        :param srev_info: The signed RevocationInfo object
        :type srev_info: SignedRevInfo
        """
        rev_info = srev_info.rev_info()
        assert isinstance(rev_info, RevocationInfo), type(rev_info)
        if not rev_info.p.ifID:
            logging.error("Trying to revoke IF with ID 0.")
            return
        with self._rev_seg_lock:
            self.local_rev_cache.add(srev_info.copy())
        packed = srev_info.copy().pack()
        # Entry name: hash of the packed revocation plus the current time.
        entry_name = "%s:%s" % (hash(packed), time.time())
        try:
            self.revobjs_cache.store(entry_name, packed)
        except ZkNoConnection as exc:
            message = ("Unable to store revocation in shared cache "
                       "(no ZK connection): %s" % exc)
            logging.error(message)
        self._remove_revoked_pcbs(rev_info)

    @abstractmethod
    def _remove_revoked_pcbs(self, rev_info):
        """
        Removes the PCBs containing the revoked interface.

        Abstract hook: concrete subclasses must override it. It is called by
        ``handle_rev_objs`` and ``_process_revocation``.

        :param rev_info: The RevocationInfo object.
        :type rev_info: RevocationInfo
        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def _pcb_list_to_remove(self, candidates, rev_info):
        """
        Calculates the list of PCBs to remove.
        Called by _remove_revoked_pcbs.

        :param candidates: Candidate PCBs.
        :type candidates: List
        :param rev_info: The RevocationInfo object.
        :type rev_info: RevocationInfo
        """
        to_remove = []
        if not rev_info.active():
            return to_remove
        processed = set()
        for cand in candidates:
            if cand.id in processed:
                continue
            processed.add(cand.id)

            # If the interface on which we received the PCB is
            # revoked, then the corresponding pcb needs to be removed.
            if (self.addr.isd_as == rev_info.isd_as()
                    and cand.pcb.ifID == rev_info.p.ifID):
                to_remove.append(cand.id)

            for asm in cand.pcb.iter_asms():
                if self._check_revocation_for_asm(rev_info, asm, False):
                    to_remove.append(cand.id)

        return to_remove

    def _handle_if_timeouts(self):
        """
        Periodically checks each interface state and issues an IF revocation
        for every interface whose state reports itself as expired.

        Runs until ``self.run_flag`` is cleared, one pass per
        ``IF_TIMEOUT_INTERVAL``. Each pass, while holding
        ``ifid_state_lock``:

        - exports the per-interface ``IF_STATE`` metric when labels are set
          (0 = active, 1 = revoked, 2 = anything else);
        - drops the cached revocation of every interface that is not expired;
        - collects the expired interfaces that need a (new) revocation and
          hands them to ``_issue_revocations``.
        """
        while self.run_flag.is_set():
            start_time = time.time()
            with self.ifid_state_lock:
                to_revoke = []
                for (ifid, if_state) in self.ifid_state.items():
                    if self._labels:
                        # Export the interface state as a numeric gauge.
                        metric = IF_STATE.labels(ifid=ifid, **self._labels)
                        if if_state.is_active():
                            metric.set(0)
                        elif if_state.is_revoked():
                            metric.set(1)
                        else:
                            metric.set(2)
                    if not if_state.is_expired():
                        # Interface hasn't timed out: any revocation cached
                        # for it is no longer needed.
                        self.if_revocations.pop(ifid, None)
                        continue
                    srev_info = self.if_revocations.get(ifid, None)
                    if if_state.is_revoked() and srev_info:
                        # The interface stays revoked until the revocation
                        # timestamp plus its TTL. A new revocation is issued
                        # REVOCATION_OVERLAP seconds before that point.
                        rev_info = srev_info.rev_info()
                        if (rev_info.p.timestamp + rev_info.p.ttl -
                                self.REVOCATION_OVERLAP > start_time):
                            # The current revocation is still outside the
                            # overlap window: nothing to re-issue yet.
                            continue
                    if not if_state.is_revoked():
                        # First time this interface is seen as down.
                        logging.info("IF %d went down.", ifid)
                    to_revoke.append(ifid)
                    if_state.revoke_if_expired()
                if to_revoke:
                    # Still inside the ifid_state_lock block here.
                    self._issue_revocations(to_revoke)
            sleep_interval(start_time, self.IF_TIMEOUT_INTERVAL,
                           "Handle IF timeouts")

    def _handle_ifstate_request(self, cpld, meta):
        """
        Answer an IFState request with the state of the known interfaces.

        Only the master (ZK lock holder) replies. Inactive interfaces are
        left out, as are revoked interfaces without a cached revocation. If
        nothing is left to report and the server is not in quiet startup, a
        warning is logged and no reply is sent.

        :param cpld: control payload whose inner union is the IFStateRequest.
        :param meta: metadata of the requester; the reply is sent to it.
        """
        req = cpld.union.union
        assert isinstance(req, IFStateRequest), type(req)
        if not self.zk.have_lock():
            # Only master replies to ifstate requests.
            return
        infos = []
        with self.ifid_state_lock:
            for ifid, state in self.ifid_state.items():
                if state.is_inactive():
                    # Inactive interfaces are not part of the update.
                    continue
                if state.is_revoked():
                    srev_info = self.if_revocations.get(ifid, None)
                    if not srev_info:
                        logging.warning(
                            "No revocation in cache for revoked IFID: %s",
                            ifid)
                        continue
                else:
                    srev_info = None
                infos.append(
                    IFStateInfo.from_values(ifid, state.is_active(),
                                            srev_info))
            if not (infos or self._quiet_startup()):
                logging.warning(
                    "No IF state info to put in IFState update for %s.", meta)
                return
        self._send_ifstate_update(infos, [meta])

    def _send_ifstate_update(self, state_infos, server_metas):
        """
        Send one IFState update to each of the given destinations.

        :param state_infos: IFStateInfo objects to put in the update.
        :param server_metas: metadata objects of the destinations; each gets
            its own copy of the payload.
        """
        ifstate_pld = IFStatePayload.from_values(state_infos)
        update = CtrlPayload(PathMgmt(ifstate_pld))
        for dst in server_metas:
            logging.debug("IFState update to %s:%s", dst.host, dst.port)
            self.send_meta(update.copy(), dst)

    def _send_ifid_updates(self):
        """
        Periodically send an IFID keep-alive over every known BR interface.

        Runs until ``run_flag`` is cleared, one pass per ``IFID_INTERVAL``.
        Passes in which this BS does not hold the ZK lock send nothing. Each
        keep-alive travels on a freshly created one-hop path and is handed to
        ``send_meta`` with the BR's internal public address as next hop.
        """
        cycle_start = time.time()
        while self.run_flag.is_set():
            sleep_interval(cycle_start, self.IFID_INTERVAL,
                           "BS._send_ifid_updates cycle")
            cycle_start = time.time()

            if self.zk.have_lock():
                # Only the master sends keep-alive messages.
                for ifid, router in self.ifid2br.items():
                    br_host, br_port = router.int_addrs.public
                    path = self._create_one_hop_path(ifid)
                    meta = self._build_meta(
                        ia=router.interfaces[ifid].isd_as,
                        host=SVCType.BS_M,
                        path=path,
                        one_hop=True)
                    keepalive = CtrlPayload(IFIDPayload.from_values(ifid))
                    self.send_meta(keepalive, meta, (br_host, br_port))

    def _check_local_cert(self):
        """
        Request a newer certificate chain once the local one nears expiry.

        While ``run_flag`` is set, the earlier expiration time of the AS and
        core AS certificates is compared with the current time. If more than
        ``config.segment_ttl`` seconds remain, the thread sleeps until only
        that much is left. Otherwise a cache-only request for the next AS
        certificate version is sent to the meta returned by ``_get_cs()``,
        and the thread waits ``CERT_REQ_RATE`` seconds before checking again.
        """
        while self.run_flag.is_set():
            my_chain = self._get_my_cert()
            expiry = min(my_chain.as_cert.expiration_time,
                         my_chain.core_as_cert.expiration_time)
            remaining = expiry - int(time.time())
            if remaining > self.config.segment_ttl:
                time.sleep(remaining - self.config.segment_ttl)
            else:
                cs_meta = self._get_cs()
                req = CertChainRequest.from_values(
                    self.addr.isd_as,
                    my_chain.as_cert.version + 1,
                    cache_only=True)
                logging.info("Request new certificate chain. Req: %s", req)
                self.send_meta(CtrlPayload(CertMgmt(req)), cs_meta)
                cs_meta.close()
                time.sleep(self.CERT_REQ_RATE)

    def _init_metrics(self):
        """
        Initialise the beacon-server metrics after the parent's metrics.

        The propagated-beacon and registered-segment counters are touched
        with an increment of 0 for each segment type ("core", "up", "down"),
        as is the issued-revocation counter. The master gauge is set to 0.
        """
        super()._init_metrics()
        labels = self._labels
        for seg_type in ("core", "up", "down"):
            for counter in (BEACONS_PROPAGATED, SEGMENTS_REGISTERED):
                counter.labels(**labels, type=seg_type).inc(0)
        REVOCATIONS_ISSUED.labels(**labels).inc(0)
        IS_MASTER.labels(**labels).set(0)
示例#4
0
class BeaconServer(SCIONElement, metaclass=ABCMeta):
    """
    The SCION PathConstructionBeacon Server.

    Attributes:
        if2rev_tokens: Contains the currently used revocation token
            hash-chain for each interface.
    """
    SERVICE_TYPE = BEACON_SERVICE
    # Amount of time units a HOF is valid (time unit is EXP_TIME_UNIT).
    HOF_EXP_TIME = 63
    # ZK path for incoming PCBs
    ZK_PCB_CACHE_PATH = "pcb_cache"
    # ZK path for revocations.
    ZK_REVOCATIONS_PATH = "rev_cache"
    # Time revocation objects are cached in memory (in seconds).
    ZK_REV_OBJ_MAX_AGE = HASHTREE_EPOCH_TIME
    # Interval at which to check for timed-out interfaces.
    IF_TIMEOUT_INTERVAL = 1

    def __init__(self, server_id, conf_dir):
        """
        Set up keys, hash tree, interface state, payload dispatch tables and
        the Zookeeper-backed shared caches.

        :param str server_id: server identifier.
        :param str conf_dir: configuration directory.
        """
        super().__init__(server_id, conf_dir)
        # TODO: add 2 policies
        self.path_policy = PathPolicy.from_file(
            os.path.join(conf_dir, PATH_POLICY_FILE))
        self.signing_key = get_sig_key(self.conf_dir)
        # Both working keys are derived from the AS master key, using
        # different derivation inputs.
        self.of_gen_key = kdf(self.config.master_as_key, b"Derive OF Key")
        self.hashtree_gen_key = kdf(self.config.master_as_key,
                                    b"Derive hashtree Key")
        # NOTE(review): this logs the whole config __dict__, which includes
        # master_as_key -- confirm that this is intended.
        logging.info(self.config.__dict__)
        # Current hash tree, the lock guarding it, and the slot for the tree
        # computed ahead of time by _create_next_tree.
        self._hash_tree = None
        self._hash_tree_lock = Lock()
        self._next_tree = None
        self._init_hash_tree()
        # One InterfaceState per known interface ID.
        self.ifid_state = {}
        for ifid in self.ifid2br:
            self.ifid_state[ifid] = InterfaceState()
        self.ifid_state_lock = RLock()
        # Dispatch table: control payload class -> payload type -> handler.
        self.CTRL_PLD_CLASS_MAP = {
            PayloadClass.PCB: {
                None: self.handle_pcb
            },
            PayloadClass.IFID: {
                None: self.handle_ifid_packet
            },
            PayloadClass.CERT: {
                CertMgmtType.CERT_CHAIN_REQ: self.process_cert_chain_request,
                CertMgmtType.CERT_CHAIN_REPLY: self.process_cert_chain_reply,
                CertMgmtType.TRC_REPLY: self.process_trc_reply,
                CertMgmtType.TRC_REQ: self.process_trc_request,
            },
            PayloadClass.PATH: {
                PMT.IFSTATE_REQ: self._handle_ifstate_request,
                PMT.REVOCATION: self._handle_revocation,
            },
        }
        # Dispatch table for SCMP payloads.
        self.SCMP_PLD_CLASS_MAP = {
            SCMPClass.PATH: {
                SCMPPathClass.REVOKED_IF: self._handle_scmp_revocation,
            },
        }

        # Zookeeper setup: pack this server's ID and address, then join the
        # party (retried by the ZK helper).
        zkid = ZkID.from_values(self.addr.isd_as, self.id,
                                [(self.addr.host, self._port)]).pack()
        self.zk = Zookeeper(self.addr.isd_as, BEACON_SERVICE, zkid,
                            self.topology.zookeepers)
        self.zk.retry("Joining party", self.zk.party_setup)
        # Shared caches: each is given the handler for the entries it reads.
        self.pcb_cache = ZkSharedCache(self.zk, self.ZK_PCB_CACHE_PATH,
                                       self._handle_pcbs_from_zk)
        self.revobjs_cache = ZkSharedCache(self.zk, self.ZK_REVOCATIONS_PATH,
                                           self.process_rev_objects)
        # In-memory revocation cache and the lock used around it.
        self.local_rev_cache = ExpiringDict(
            1000, HASHTREE_EPOCH_TIME + HASHTREE_EPOCH_TOLERANCE)
        self._rev_seg_lock = RLock()

    def _init_hash_tree(self):
        """
        Create the initial connected hash tree over all known interface IDs,
        using the hash tree key and SHA256.
        """
        if_ids = list(self.ifid2br)
        self._hash_tree = ConnectedHashTree(
            self.addr.isd_as, if_ids, self.hashtree_gen_key, HashType.SHA256)

    def _get_ht_proof(self, if_id):
        with self._hash_tree_lock:
            return self._hash_tree.get_proof(if_id)

    def _get_ht_root(self):
        with self._hash_tree_lock:
            return self._hash_tree.get_root()

    def propagate_downstream_pcb(self, pcb):
        """
        Send an extended copy of the beacon over every child interface.

        Child interfaces without a known remote interface ID are skipped, as
        are those for which no extended PCB could be built.

        :param pcb: path segment.
        :type pcb: PathSegment
        :returns: mapping of (child ISD-AS, interface ID) to the list of
            short IDs of the PCBs sent there.
        """
        sent = defaultdict(list)
        for child in self.topology.child_interfaces:
            if child.to_if_id:
                out_pcb, meta = self._mk_prop_pcb_meta(
                    pcb.copy(), child.isd_as, child.if_id)
                if out_pcb:
                    self.send_meta(out_pcb, meta)
                    sent[(child.isd_as, child.if_id)].append(pcb.short_id())
        return sent

    def _mk_prop_pcb_meta(self, pcb, dst_ia, egress_if):
        """
        Extend ``pcb`` with this AS's marking for ``egress_if``, sign it and
        build the one-hop metadata used to send it to ``dst_ia``.

        :returns: ``(pcb, meta)``, or ``(None, None)`` when no AS marking
            could be created.
        """
        ts = pcb.get_timestamp()
        asm = self._create_asm(pcb.p.ifID, egress_if, ts, pcb.last_hof())
        if asm:
            pcb.add_asm(asm)
            pcb.sign(self.signing_key)
            meta = self._build_meta(
                ia=dst_ia,
                host=SVCType.BS_A,
                path=self._create_one_hop_path(egress_if),
                one_hop=True)
            return pcb, meta
        return None, None

    def _create_one_hop_path(self, egress_if):
        """
        Build a two-hop path whose first hop field leaves through
        ``egress_if`` and carries a MAC, and whose second hop field is empty.
        """
        now = int(SCIONTime.get_time())
        info_field = InfoOpaqueField.from_values(
            now, self.addr.isd_as[0], hops=2)
        first_hop = HopOpaqueField.from_values(
            self.HOF_EXP_TIME, 0, egress_if)
        first_hop.set_mac(self.of_gen_key, now, None)
        # The second hop field is intentionally left empty.
        hop_fields = [first_hop, HopOpaqueField()]
        return SCIONPath.from_values(info_field, hop_fields)

    def _mk_if_info(self, if_id):
        """
        Describe the remote end of interface ``if_id`` as a dict with the
        keys "remote_ia", "remote_if" and "mtu".

        For a falsy ``if_id`` (ingress/egress interface 0 while building
        ASMarkings) all three values are zero placeholders.
        """
        info = {"remote_ia": ISD_AS.from_values(0, 0), "remote_if": 0, "mtu": 0}
        if if_id:
            intf = self.ifid2br[if_id].interfaces[if_id]
            info.update(remote_ia=intf.isd_as,
                        remote_if=intf.to_if_id,
                        mtu=intf.mtu)
        return info

    @abstractmethod
    def handle_pcbs_propagation(self):
        """
        Main loop to propagate received beacons.

        Abstract hook: concrete subclasses must override it. ``worker`` calls
        it once the configured propagation time has elapsed since the last
        call.

        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def _log_propagations(self, propagated_pcbs):
        for (isd_as, if_id), pcbs in propagated_pcbs.items():
            logging.debug("Propagated %d PCBs to %s via %s (%s)", len(pcbs),
                          isd_as, if_id, ", ".join(pcbs))

    def _handle_pcbs_from_zk(self, pcbs):
        """
        Parse and handle the raw PCBs read from the shared ZK cache.

        Entries that cannot be parsed are logged and skipped.

        :param pcbs: list of raw PCBs.
        """
        for raw_pcb in pcbs:
            try:
                segment = PathSegment.from_raw(raw_pcb)
            except SCIONParseError as e:
                logging.error("Unable to parse raw pcb: %s", e)
            else:
                self.handle_pcb(segment)
        if pcbs:
            logging.debug("Processed %s PCBs from ZK", len(pcbs))

    def handle_pcb(self, pcb, meta=None):
        """
        Filter an incoming PCB and queue it for verification.

        When ``meta`` is given, the PCB's interface ID is set from the
        ingress interface of the hop field of the path it arrived on. The PCB
        is dropped if the path policy rejects it or ``_filter_pcb`` returns a
        falsy value. Otherwise it is passed on for processing, with
        ``continue_seg_processing`` as the follow-up callback.

        :param pcb: the path segment.
        :param meta: metadata of the sender, or None.
        """
        if meta:
            pcb.p.ifID = meta.path.get_hof().ingress_if
        try:
            self.path_policy.check_filters(pcb)
        except SCIONPathPolicyViolated as e:
            message = ("Segment dropped due to path policy: %s\n%s" %
                       (e, pcb.short_desc()))
            logging.debug(message)
            return
        if not self._filter_pcb(pcb):
            message = ("Segment dropped due to looping: %s" %
                       pcb.short_desc())
            logging.debug(message)
            return
        self._process_path_seg(
            PathSegMeta(pcb, self.continue_seg_processing, meta))

    def continue_seg_processing(self, seg_meta):
        """
        Continue processing a PCB after it has been verified.

        A PCB that carries sender metadata (received from the network, not
        from ZK) is first stored in the shared PCB cache so that the other
        beacon servers of this AS see it; a missing ZK connection is logged,
        not raised. Then the extensions are handled and the beacon is passed
        to ``_handle_verified_beacon``.

        :param seg_meta: segment metadata holding the PCB (``seg``) and the
            sender metadata (``meta``).
        """
        pcb = seg_meta.seg
        logging.debug("Successfully verified PCB %s", pcb.short_id())
        if seg_meta.meta:
            cache_key = "%s-%s" % (pcb.get_hops_hash(hex=True), time.time())
            packed_pcb = pcb.copy().pack()
            try:
                self.pcb_cache.store(cache_key, packed_pcb)
            except ZkNoConnection:
                logging.error("Unable to store PCB in shared cache: "
                              "no connection to ZK")
        self.handle_ext(pcb)
        self._handle_verified_beacon(pcb)

    def _filter_pcb(self, pcb, dst_ia=None):
        """
        Decide whether a PCB should be processed further.

        This implementation accepts every PCB and ignores both arguments.
        ``handle_pcb`` drops a segment ("due to looping") when this returns a
        falsy value.

        :returns: always True.
        """
        return True

    def handle_ext(self, pcb):
        """
        Handle the extensions of a beacon.

        A SIBRA extension is only debug-logged. Every routing policy
        extension found in the AS markings is passed to
        ``handle_routing_pol_ext``.
        """
        if pcb.is_sibra():
            logging.debug("%s", pcb.sibra_ext)
        # Lazy: each marking is queried only when the loop reaches it.
        policies = (asm.routing_pol_ext() for asm in pcb.iter_asms())
        for policy in policies:
            if policy:
                self.handle_routing_pol_ext(policy)

    def handle_routing_pol_ext(self, ext):
        """
        Handle a routing policy extension. For now it is only debug-logged.
        """
        # TODO(Sezer): Implement routing policy extension handling
        message = "Routing policy extension: %s" % ext
        logging.debug(message)

    @abstractmethod
    def register_segments(self):
        """
        Registers paths according to the received beacons.

        Abstract hook: concrete subclasses must override it. ``worker`` calls
        it when path registration is enabled in the configuration and the
        registration time has elapsed; a SCIONKeyError raised here is logged
        by ``worker``.

        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def _log_registrations(self, registrations, seg_type):
        for (dst_meta, dst_type), pcbs in registrations.items():
            logging.debug("Registered %d %s-segments @ %s:%s (%s)", len(pcbs),
                          seg_type, dst_type.upper(), dst_meta,
                          ", ".join(pcbs))

    def _create_asm(self, in_if, out_if, ts, prev_hof):
        """
        Build this AS's marking for a PCB.

        :returns: an ASMarking built from the PCB markings for the given
            interfaces, the current TRC and leaf certificate versions, the
            hash tree root and the topology MTU; None when no PCB marking
            could be created.
        """
        markings = list(self._create_pcbms(in_if, out_if, ts, prev_hof))
        if markings:
            _, cert_ver = self._get_my_cert().get_leaf_isd_as_ver()
            return ASMarking.from_values(
                self.addr.isd_as, self._get_my_trc().version, cert_ver,
                markings, self._get_ht_root(), self.topology.mtu)
        return None

    def _create_pcbms(self, in_if, out_if, ts, prev_hof):
        up_pcbm = self._create_pcbm(in_if, out_if, ts, prev_hof)
        if not up_pcbm:
            return
        yield up_pcbm
        for intf in sorted(self.topology.peer_interfaces):
            in_if = intf.if_id
            with self.ifid_state_lock:
                if (not self.ifid_state[in_if].is_active()
                        and not self._quiet_startup()):
                    continue
            peer_pcbm = self._create_pcbm(in_if,
                                          out_if,
                                          ts,
                                          up_pcbm.hof(),
                                          xover=True)
            if peer_pcbm:
                yield peer_pcbm

    def _create_pcbm(self, in_if, out_if, ts, prev_hof, xover=False):
        """
        Build one PCB marking with a MACed hop field for ``in_if``/``out_if``.

        :returns: the PCBMarking, or None when either interface has a
            non-zero remote ISD-AS but no known remote interface ID.
        """

        def _remote_if_unknown(info):
            # Non-zero remote AS, but its interface ID is not known.
            return info["remote_ia"].int() and not info["remote_if"]

        ingress = self._mk_if_info(in_if)
        if _remote_if_unknown(ingress):
            return None
        egress = self._mk_if_info(out_if)
        if _remote_if_unknown(egress):
            return None
        hop_field = HopOpaqueField.from_values(
            self.HOF_EXP_TIME, in_if, out_if, xover=xover)
        hop_field.set_mac(self.of_gen_key, ts, prev_hof)
        return PCBMarking.from_values(
            ingress["remote_ia"], ingress["remote_if"], ingress["mtu"],
            egress["remote_ia"], egress["remote_if"], hop_field)

    def _terminate_pcb(self, pcb):
        """
        Copies a PCB, terminates it and adds the segment ID.

        Terminating a PCB means adding a opaque field with the egress IF set
        to 0, i.e., there is no AS to forward a packet containing this path
        segment to.
        """
        pcb = pcb.copy()
        asm = self._create_asm(pcb.p.ifID, 0, pcb.get_timestamp(),
                               pcb.last_hof())
        if not asm:
            return None
        pcb.add_asm(asm)
        return pcb

    def handle_ifid_packet(self, pld, meta):
        """
        Update the interface state for the interface named in the payload.

        The remote interface ID is recorded on the border router's interface
        and the interface state is refreshed. If the interface was not
        already active and this BS holds the ZK lock, all border routers are
        told that the interface is up.

        :param pld: The IFIDPayload.
        :type pld: IFIDPayload
        :param meta: metadata of the sender (not used here).
        :raises SCIONKeyError: if the relay interface ID is unknown.
        """
        ifid = pld.p.relayIF
        with self.ifid_state_lock:
            if ifid not in self.ifid_state:
                raise SCIONKeyError("Invalid IF %d in IFIDPayload" % ifid)
            self.ifid2br[ifid].interfaces[ifid].to_if_id = pld.p.origIF
            prev_state = self.ifid_state[ifid].update()
            if prev_state == InterfaceState.INACTIVE:
                logging.info("IF %d activated", ifid)
            elif prev_state in (InterfaceState.TIMED_OUT,
                                InterfaceState.REVOKED):
                logging.info("IF %d came back up.", ifid)
            if prev_state == InterfaceState.ACTIVE:
                # No state change: nothing to announce.
                return
            if not self.zk.have_lock():
                return
            # Inform BRs about the interface coming up.
            up_info = IFStateInfo.from_values(
                ifid, True, self._get_ht_proof(ifid))
            update = IFStatePayload.from_values([up_info])
            for router in self.topology.border_routers:
                host, port = router.int_addrs[0].public[0]
                dst = UDPMetadata.from_values(host=host, port=port)
                self.send_meta(update.copy(), dst, (host, port))

    def run(self):
        """
        Start the Beacon Server's background threads, then call the parent
        class's ``run``.

        Every task is wrapped in ``thread_safety_net`` and runs in its own
        daemon thread. Threads are started in the order listed below.
        """
        background_tasks = (
            (self.worker, "BS.worker"),
            # https://github.com/netsec-ethz/scion/issues/308:
            (self._handle_if_timeouts, "BS._handle_if_timeouts"),
            (self._create_next_tree, "BS._create_next_tree"),
            (self._check_trc_cert_reqs, "Elem.check_trc_cert_reqs"),
        )
        for task, thread_name in background_tasks:
            bg_thread = threading.Thread(target=thread_safety_net,
                                         args=(task, ),
                                         name=thread_name,
                                         daemon=True)
            bg_thread.start()
        super().run()

    def _create_next_tree(self):
        """
        Background loop that precomputes the next hash tree.

        While ``run_flag`` is set, it sleeps until HASHTREE_UPDATE_WINDOW
        seconds before the end of the current TTL window, computes the next
        tree and stores it in ``self._next_tree`` under the hash tree lock,
        where ``_maintain_hash_tree`` picks it up.
        """
        last_ttl_window = 0
        while self.run_flag.is_set():
            start = time.time()
            cur_ttl_window = ConnectedHashTree.get_ttl_window()
            # Time left until the update window of the current TTL opens.
            time_to_sleep = (ConnectedHashTree.get_time_till_next_ttl() -
                             HASHTREE_UPDATE_WINDOW)
            if cur_ttl_window == last_ttl_window:
                # A tree was already computed during this TTL window, so
                # wait one more full TTL.
                time_to_sleep += HASHTREE_TTL
            if time_to_sleep > 0:
                sleep_interval(start, time_to_sleep, "BS._create_next_tree",
                               self._quiet_startup())

            # at this point, there should be <= HASHTREE_UPDATE_WINDOW
            # seconds left in current ttl
            logging.info("Started computing hashtree for next TTL window (%d)",
                         cur_ttl_window + 2)
            # Remember the window in which this computation was started.
            last_ttl_window = ConnectedHashTree.get_ttl_window()

            ht_start = time.time()
            ifs = list(self.ifid2br.keys())
            tree = ConnectedHashTree.get_next_tree(self.addr.isd_as, ifs,
                                                   self.hashtree_gen_key,
                                                   HashType.SHA256)
            ht_end = time.time()
            # Publish the new tree; only this assignment happens under the
            # lock, the computation above does not.
            with self._hash_tree_lock:
                self._next_tree = tree
            logging.info(
                "Finished computing hashtree for TTL window %d in %.3fs" %
                (cur_ttl_window + 2, ht_end - ht_start))

    def _maintain_hash_tree(self):
        """
        Maintain the hash tree: update the windows in the connected tree
        with the precomputed next tree.

        If no next tree is available, a critical message is logged and
        ``kill_self`` is called.
        """
        with self._hash_tree_lock:
            pending_tree = self._next_tree
            if pending_tree is None:
                logging.critical("Did not create hashtree in time; dying")
                kill_self()
            else:
                self._hash_tree.update(pending_tree)
                self._next_tree = None
        logging.info("New Hash Tree TTL window beginning: %s",
                     ConnectedHashTree.get_ttl_window())

    def worker(self):
        """
        Periodic main loop of the Beacon Server (one pass per second).

        Every pass processes the shared PCB and revocation caches from ZK,
        switches to the next hash tree when the TTL window has changed, and
        tries to take the ZK lock. Only when the lock is held (i.e. when
        master) are cache entries expired, PCBs propagated and, if the
        configuration enables it, path segments registered.
        """
        cycle_len = 1.0
        propagated_at = registered_at = 0
        seen_ttl_window = ConnectedHashTree.get_ttl_window()
        cycle_start = time.time()
        while self.run_flag.is_set():
            sleep_interval(cycle_start, cycle_len, "BS.worker cycle",
                           self._quiet_startup())
            cycle_start = time.time()
            try:
                self.zk.wait_connected()
                self.pcb_cache.process()
                self.revobjs_cache.process()
                self.handle_rev_objs()

                ttl_window = ConnectedHashTree.get_ttl_window()
                if ttl_window != seen_ttl_window:
                    # A new TTL window has started: roll the hash tree over.
                    self._maintain_hash_tree()
                    seen_ttl_window = ttl_window

                lock_result = self.zk.get_lock(lock_timeout=0, conn_timeout=0)
                if not lock_result:
                    # Lock not acquired: skip all master-only work.
                    continue
                if lock_result == ZK_LOCK_SUCCESS:
                    logging.info("Became master")
                    self._became_master()
                self.pcb_cache.expire(self.config.propagation_time * 10)
                self.revobjs_cache.expire(self.ZK_REV_OBJ_MAX_AGE)
            except ZkNoConnection:
                continue
            current = time.time()
            if current - propagated_at >= self.config.propagation_time:
                self.handle_pcbs_propagation()
                propagated_at = current
            registration_due = (
                self.config.registers_paths and
                current - registered_at >= self.config.registration_time)
            if registration_due:
                try:
                    self.register_segments()
                except SCIONKeyError as e:
                    logging.error("Error while registering segments: %s", e)
                # The timestamp is advanced even if registration failed.
                registered_at = current

    def _became_master(self):
        """
        Called when a BS becomes the new master. Resets some state that will be
        rebuilt over time.
        """
        # Reset all timed-out and revoked interfaces to inactive.
        with self.ifid_state_lock:
            for (_, ifstate) in self.ifid_state.items():
                if not ifstate.is_active():
                    ifstate.reset()

    def _get_my_trc(self):
        return self.trust_store.get_trc(self.addr.isd_as[0])

    def _get_my_cert(self):
        return self.trust_store.get_cert(self.addr.isd_as)

    @abstractmethod
    def _handle_verified_beacon(self, pcb):
        """
        Once a beacon has been verified, place it into the right containers.

        :param pcb: verified path segment.
        :type pcb: PathSegment
        """
        raise NotImplementedError

    def process_rev_objects(self, rev_infos):
        """
        Parse raw revocation infos and add them to the local revocation cache.

        Entries that fail to parse are logged and skipped. The whole batch is
        handled while holding ``_rev_seg_lock``.

        :param rev_infos: iterable of raw (packed) revocation infos.
        """
        with self._rev_seg_lock:
            for blob in rev_infos:
                try:
                    parsed = RevocationInfo.from_raw(blob)
                except SCIONParseError as err:
                    logging.error(
                        "Error processing revocation info from ZK: %s", err)
                else:
                    self.local_rev_cache[parsed] = parsed.copy()

    def _issue_revocation(self, if_id):
        """
        Store a RevocationInfo in ZK and send a revocation to all BRs.

        :param if_id: The interface that needs to be revoked.
        :type if_id: int
        """
        # Only the master BS issues revocations.
        if not self.zk.have_lock():
            return
        rev_info = self._get_ht_proof(if_id)
        logging.info("Issuing revocation: %s", rev_info.short_desc())
        # Issue revocation to all BRs.
        info = IFStateInfo.from_values(if_id, False, rev_info)
        pld = IFStatePayload.from_values([info])
        for br in self.topology.border_routers:
            br_addr, br_port = br.int_addrs[0].public[0]
            meta = UDPMetadata.from_values(host=br_addr, port=br_port)
            self.send_meta(pld.copy(), meta, (br_addr, br_port))
        self._process_revocation(rev_info)
        self._send_rev_to_local_ps(rev_info)

    def _send_rev_to_local_ps(self, rev_info):
        """
        Sends the given revocation to its local path server.
        :param rev_info: The RevocationInfo object
        :type rev_info: RevocationInfo
        """
        if self.zk.have_lock() and self.topology.path_servers:
            try:
                addr, port = self.dns_query_topo(PATH_SERVICE)[0]
            except SCIONServiceLookupError:
                # If there are no local path servers, stop here.
                return
            meta = UDPMetadata.from_values(host=addr, port=port)
            self.send_meta(rev_info.copy(), meta)

    def _handle_scmp_revocation(self, pld, meta):
        """
        Handle a revocation received via SCMP.

        The revocation info is parsed from ``pld.info.rev_info``, logged and
        passed to ``_process_revocation``.
        """
        revocation = RevocationInfo.from_raw(pld.info.rev_info)
        logging.debug("Received revocation via SCMP: %s (from %s)",
                      revocation.short_desc(), meta)
        self._process_revocation(revocation)

    def _handle_revocation(self, rev_info, meta):
        logging.debug("Received revocation via TCP/UDP: %s (from %s)",
                      rev_info.short_desc(), meta)
        if not self._validate_revocation(rev_info):
            return
        self._process_revocation(rev_info)

    def handle_rev_objs(self):
        """
        Apply every revocation in the local revocation cache.

        Holds ``_rev_seg_lock`` while calling ``_remove_revoked_pcbs`` for each
        cached revocation.
        """
        with self._rev_seg_lock:
            cached = self.local_rev_cache.values()
            for revocation in cached:
                self._remove_revoked_pcbs(revocation)

    def _process_revocation(self, rev_info):
        """
        Record a revocation locally and in the shared cache, then drop the
        PCBs it affects.

        A revocation for interface ID 0 is rejected with an error log. Failure
        to reach ZK when storing the revocation is logged but does not prevent
        the removal of revoked PCBs.

        :param rev_info: The RevocationInfo object
        :type rev_info: RevocationInfo
        """
        assert isinstance(rev_info, RevocationInfo)
        if not rev_info.p.ifID:
            logging.error("Trying to revoke IF with ID 0.")
            return
        with self._rev_seg_lock:
            self.local_rev_cache[rev_info] = rev_info.copy()
        packed = rev_info.copy().pack()
        # Entry name combines the hash of the packed revocation with the
        # current time.
        cache_key = "%s:%s" % (hash(packed), time.time())
        try:
            self.revobjs_cache.store(cache_key, packed)
        except ZkNoConnection as exc:
            logging.error("Unable to store revocation in shared cache "
                          "(no ZK connection): %s" % exc)
        self._remove_revoked_pcbs(rev_info)

    @abstractmethod
    def _remove_revoked_pcbs(self, rev_info):
        """
        Removes the PCBs containing the revoked interface.

        :param rev_info: The RevocationInfo object.
        :type rev_info: RevocationInfo
        """
        raise NotImplementedError

    def _pcb_list_to_remove(self, candidates, rev_info):
        """
        Calculates the list of PCBs to remove.
        Called by _remove_revoked_pcbs.

        A candidate is selected when either the revocation is for this AS, is
        verified against the local hash tree root and names the interface
        stored in the candidate's ``pcb.p.ifID``, or when
        ``_verify_revocation_for_asm`` accepts the revocation for one of the
        candidate's AS markings. Each candidate ID is reported at most once,
        even if several of these conditions hold.

        :param candidates: Candidate PCBs.
        :type candidates: List
        :param rev_info: The RevocationInfo object.
        :type rev_info: RevocationInfo
        :returns: IDs of the candidates to remove; empty if the revocation's
            epoch does not verify.
        :rtype: list
        """
        # Both checks depend only on the revocation, not on the candidate, so
        # they are evaluated once instead of once per candidate.
        if not ConnectedHashTree.verify_epoch(rev_info.p.epoch):
            return []
        root_verify = ConnectedHashTree.verify(rev_info, self._get_ht_root())
        local_revocation = (root_verify
                            and self.addr.isd_as == rev_info.isd_as())

        to_remove = []
        processed = set()
        for cand in candidates:
            if cand.id in processed:
                continue
            processed.add(cand.id)
            # If the interface on which we received the PCB is revoked, then
            # the corresponding pcb needs to be removed.
            ingress_revoked = (local_revocation
                               and cand.pcb.p.ifID == rev_info.p.ifID)
            # any() stops at the first matching AS marking, so an ID is never
            # appended more than once.
            if ingress_revoked or any(
                    self._verify_revocation_for_asm(rev_info, asm, False)
                    for asm in cand.pcb.iter_asms()):
                to_remove.append(cand.id)

        return to_remove

    def _handle_if_timeouts(self):
        """
        Periodically checks each interface state and issues an if revocation, if
        no keep-alive message was received for IFID_TOUT.
        """
        if_id_last_revoked = defaultdict(int)
        while self.run_flag.is_set():
            start_time = time.time()
            with self.ifid_state_lock:
                for (if_id, if_state) in self.ifid_state.items():
                    cur_epoch = ConnectedHashTree.get_current_epoch()
                    if not if_state.is_expired() or (
                            if_state.is_revoked()
                            and if_id_last_revoked[if_id] == cur_epoch):
                        # Either the interface hasn't timed out, or it's already revoked for this
                        # epoch
                        continue
                    if_id_last_revoked[if_id] = cur_epoch
                    if not if_state.is_revoked():
                        logging.info("IF %d went down.", if_id)
                    self._issue_revocation(if_id)
                    if_state.revoke_if_expired()
            sleep_interval(start_time, self.IF_TIMEOUT_INTERVAL,
                           "Handle IF timeouts")

    def _handle_ifstate_request(self, req, meta):
        # Only master replies to ifstate requests.
        if not self.zk.have_lock():
            return
        assert isinstance(req, IFStateRequest)
        infos = []
        with self.ifid_state_lock:
            if req.p.ifID == IFStateRequest.ALL_INTERFACES:
                ifid_states = self.ifid_state.items()
            elif req.p.ifID in self.ifid_state:
                ifid_states = [(req.p.ifID, self.ifid_state[req.p.ifID])]
            else:
                logging.error(
                    "Received ifstate request from %s for unknown "
                    "interface %s.", meta, req.p.ifID)
                return

            for (ifid, state) in ifid_states:
                # Don't include inactive interfaces in response.
                if state.is_inactive():
                    continue
                info = IFStateInfo.from_values(ifid, state.is_active(),
                                               self._get_ht_proof(ifid))
                infos.append(info)
        if not infos and not self._quiet_startup():
            logging.warning("No IF state info to put in response. Req: %s" %
                            req.short_desc())
            return
        payload = IFStatePayload.from_values(infos)
        self.send_meta(payload, meta, (meta.host, meta.port))
示例#5
0
文件: base.py 项目: sasjafor/scion
class BeaconServer(SCIONElement, metaclass=ABCMeta):
    """
    The SCION PathConstructionBeacon Server.

    Attributes:
        if2rev_tokens: Contains the currently used revocation token
            hash-chain for each interface.
    """
    SERVICE_TYPE = BEACON_SERVICE
    # Amount of time units a HOF is valid (time unit is EXP_TIME_UNIT).
    HOF_EXP_TIME = 63
    # Timeout for TRC or Certificate requests.
    REQUESTS_TIMEOUT = 10
    # ZK path for incoming PCBs
    ZK_PCB_CACHE_PATH = "pcb_cache"
    # ZK path for revocations.
    ZK_REVOCATIONS_PATH = "rev_cache"
    # Time revocation objects are cached in memory (in seconds).
    ZK_REV_OBJ_MAX_AGE = HASHTREE_EPOCH_TIME
    # Interval at which to check for timed-out interfaces.
    IF_TIMEOUT_INTERVAL = 1

    def __init__(self, server_id, conf_dir):
        """
        Set up beacon server state, payload handler maps and the Zookeeper
        connection with its shared caches.

        :param str server_id: server identifier.
        :param str conf_dir: configuration directory.
        """
        super().__init__(server_id, conf_dir)
        # TODO: add 2 policies
        self.path_policy = PathPolicy.from_file(
            os.path.join(conf_dir, PATH_POLICY_FILE))
        # Beacons that could not be verified yet (TRC/certificate missing).
        self.unverified_beacons = deque()
        # (ISD, TRC version) -> time the TRC was last requested.
        self.trc_requests = {}
        self.trcs = {}
        sig_key_file = get_sig_key_file_path(self.conf_dir)
        self.signing_key = base64.b64decode(read_file(sig_key_file))
        self.of_gen_key = PBKDF2(self.config.master_as_key, b"Derive OF Key")
        self.hashtree_gen_key = PBKDF2(self.config.master_as_key,
                                       b"Derive hashtree Key")
        # Log the AS configuration, but leave out the master AS key: the keys
        # above are derived from it, so it must not end up in log files.
        logging.info({
            name: value
            for name, value in self.config.__dict__.items()
            if name != "master_as_key"
        })
        self._hash_tree = None
        self._hash_tree_lock = Lock()
        # Tree staged by _create_next_tree, consumed by _maintain_hash_tree.
        self._next_tree = None
        self._init_hash_tree()
        # One InterfaceState per known interface ID.
        self.ifid_state = {}
        for ifid in self.ifid2br:
            self.ifid_state[ifid] = InterfaceState()
        self.ifid_state_lock = RLock()
        # Control payload class/type -> handler.
        self.CTRL_PLD_CLASS_MAP = {
            PayloadClass.PCB: {
                None: self.handle_pcb
            },
            PayloadClass.IFID: {
                None: self.handle_ifid_packet
            },
            PayloadClass.CERT: {
                CertMgmtType.CERT_CHAIN_REPLY: self.process_cert_chain_rep,
                CertMgmtType.TRC_REPLY: self.process_trc_rep,
            },
            PayloadClass.PATH: {
                PMT.IFSTATE_REQ: self._handle_ifstate_request,
                PMT.REVOCATION: self._handle_revocation,
            },
        }
        # SCMP class/type -> handler.
        self.SCMP_PLD_CLASS_MAP = {
            SCMPClass.PATH: {
                SCMPPathClass.REVOKED_IF: self._handle_scmp_revocation,
            },
        }

        zkid = ZkID.from_values(self.addr.isd_as, self.id,
                                [(self.addr.host, self._port)]).pack()
        self.zk = Zookeeper(self.addr.isd_as, BEACON_SERVICE, zkid,
                            self.topology.zookeepers)
        self.zk.retry("Joining party", self.zk.party_setup)
        # PCBs received directly, drained by process_pcb_queue.
        self.incoming_pcbs = deque()
        self.pcb_cache = ZkSharedCache(self.zk, self.ZK_PCB_CACHE_PATH,
                                       self.process_pcbs)
        self.revobjs_cache = ZkSharedCache(self.zk, self.ZK_REVOCATIONS_PATH,
                                           self.process_rev_objects)
        # NOTE(review): arguments are presumably (max entries, max age in
        # seconds) -- confirm against ExpiringDict.
        self.local_rev_cache = ExpiringDict(
            1000, HASHTREE_EPOCH_TIME + HASHTREE_EPOCH_TOLERANCE)
        self.local_rev_cache_lock = Lock()

    def _init_hash_tree(self):
        """
        Build the initial ConnectedHashTree from the local ISD-AS, the known
        interface IDs and the hash tree generation key.
        """
        interface_ids = list(self.ifid2br.keys())
        tree = ConnectedHashTree(self.addr.isd_as, interface_ids,
                                 self.hashtree_gen_key)
        self._hash_tree = tree

    def _get_ht_proof(self, if_id):
        with self._hash_tree_lock:
            return self._hash_tree.get_proof(if_id)

    def _get_ht_root(self):
        with self._hash_tree_lock:
            return self._hash_tree.get_root()

    def propagate_downstream_pcb(self, pcb):
        """
        Send an extended copy of the beacon to every child border router.

        Routers whose interface has no ``to_if_id`` yet are skipped, as are
        routers for which no propagation PCB could be built.

        :param pcb: path segment.
        :type pcb: PathSegment
        """
        for router in self.topology.child_border_routers:
            iface = router.interface
            if not iface.to_if_id:
                continue
            out_pcb, out_meta = self._mk_prop_pcb_meta(
                pcb.copy(), iface.isd_as, iface.if_id)
            if out_pcb:
                self.send_meta(out_pcb, out_meta)
                logging.info("Downstream PCB propagated to %s via IF %s",
                             iface.isd_as, iface.if_id)

    def _mk_prop_pcb_meta(self, pcb, dst_ia, egress_if):
        ts = pcb.get_timestamp()
        asm = self._create_asm(pcb.p.ifID, egress_if, ts, pcb.last_hof())
        if not asm:
            return None, None
        pcb.add_asm(asm)
        pcb.sign(self.signing_key)
        one_hop_path = self._create_one_hop_path(egress_if)
        if self.DefaultMeta == TCPMetadata:
            return pcb, self.DefaultMeta.from_values(ia=dst_ia,
                                                     host=SVCType.BS_A,
                                                     path=one_hop_path,
                                                     flags=TCPFlags.ONEHOPPATH)
        return pcb, UDPMetadata.from_values(ia=dst_ia,
                                            host=SVCType.BS_A,
                                            path=one_hop_path,
                                            ext_hdrs=[OneHopPathExt()])

    def _create_one_hop_path(self, egress_if):
        """
        Build a two-hop SCIONPath whose first hop field leaves via
        ``egress_if`` and whose second hop field is empty.
        """
        now = int(SCIONTime.get_time())
        info_field = InfoOpaqueField.from_values(now, self.addr.isd_as[0],
                                                 hops=2)
        first_hop = HopOpaqueField.from_values(self.HOF_EXP_TIME, 0,
                                               egress_if)
        first_hop.set_mac(self.of_gen_key, now, None)
        # The second hop field is left empty.
        hop_fields = [first_hop, HopOpaqueField()]
        return SCIONPath.from_values(info_field, hop_fields)

    def _mk_if_info(self, if_id):
        """
        Describe the remote end of interface ``if_id`` as a dict with the keys
        ``remote_ia``, ``remote_if`` and ``mtu``.

        For a falsy ``if_id`` (ingress/egress interface 0 while building
        ASMarkings) a zeroed placeholder is returned instead.
        """
        if not if_id:
            return {
                "remote_ia": ISD_AS.from_values(0, 0),
                "remote_if": 0,
                "mtu": 0,
            }
        iface = self.ifid2br[if_id].interface
        return {
            "remote_ia": iface.isd_as,
            "remote_if": iface.to_if_id,
            "mtu": iface.mtu,
        }

    @abstractmethod
    def handle_pcbs_propagation(self):
        """
        Propagate received beacons.

        Abstract: subclasses must override this; the base implementation
        always raises.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError

    def handle_pcb(self, pcb, meta):
        """
        Receives beacon and stores it for processing.

        The ingress interface is taken from the hop field of ``meta.path``.
        Beacons rejected by the path policy filters are dropped; accepted ones
        are queued locally and written to the shared PCB cache. ``meta`` is
        closed in both cases.

        :param pcb: received path segment.
        :param meta: metadata the beacon arrived with.
        """
        pcb.p.ifID = meta.path.get_hof().ingress_if
        # meta is only needed for the ingress interface. Close it before the
        # policy check so that rejected beacons release it too.
        meta.close()
        if not self.path_policy.check_filters(pcb):
            return
        self.incoming_pcbs.append(pcb)
        entry_name = "%s-%s" % (pcb.get_hops_hash(hex=True), time.time())
        try:
            self.pcb_cache.store(entry_name, pcb.copy().pack())
        except ZkNoConnection:
            logging.error("Unable to store PCB in shared cache: "
                          "no connection to ZK")

    def handle_ext(self, pcb):
        """
        Handle beacon extensions.

        Currently this only logs the SIBRA extension, at debug level, for
        beacons where ``pcb.is_sibra()`` is true.
        """
        if not pcb.is_sibra():
            return
        logging.debug("%s", pcb.sibra_ext)

    @abstractmethod
    def process_pcbs(self, pcbs, raw=True):
        """
        Process new beacons and add them to the beacon list.

        Abstract: subclasses must override this; the base implementation
        always raises.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError

    def process_pcb_queue(self):
        """
        Drain the queue of incoming PCBs and hand the batch to
        ``process_pcbs`` as already-parsed (``raw=False``) beacons.
        """
        queue = self.incoming_pcbs
        drained = []
        while queue:
            drained.append(queue.popleft())
        self.process_pcbs(drained, raw=False)
        logging.debug("Processed %d pcbs from incoming queue", len(drained))

    @abstractmethod
    def register_segments(self):
        """
        Register paths according to the received beacons.

        Abstract: subclasses must override this; the base implementation
        always raises.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError

    def _create_asm(self, in_if, out_if, ts, prev_hof):
        pcbms = list(self._create_pcbms(in_if, out_if, ts, prev_hof))
        if not pcbms:
            return None
        chain = self._get_my_cert()
        _, cert_ver = chain.get_leaf_isd_as_ver()
        return ASMarking.from_values(self.addr.isd_as,
                                     self._get_my_trc().version, cert_ver,
                                     pcbms, self._get_ht_root(),
                                     self.topology.mtu, chain)

    def _create_pcbms(self, in_if, out_if, ts, prev_hof):
        up_pcbm = self._create_pcbm(in_if, out_if, ts, prev_hof)
        if not up_pcbm:
            return
        yield up_pcbm
        for br in sorted(self.topology.peer_border_routers):
            in_if = br.interface.if_id
            with self.ifid_state_lock:
                if (not self.ifid_state[in_if].is_active()
                        and not self._quiet_startup()):
                    logging.warning('Peer ifid:%d inactive (not added).',
                                    in_if)
                    continue
            peer_pcbm = self._create_pcbm(in_if,
                                          out_if,
                                          ts,
                                          up_pcbm.hof(),
                                          xover=True)
            if peer_pcbm:
                yield peer_pcbm

    def _create_pcbm(self, in_if, out_if, ts, prev_hof, xover=False):
        in_info = self._mk_if_info(in_if)
        if in_info["remote_ia"].int() and not in_info["remote_if"]:
            return None
        out_info = self._mk_if_info(out_if)
        if out_info["remote_ia"].int() and not out_info["remote_if"]:
            return None
        hof = HopOpaqueField.from_values(self.HOF_EXP_TIME,
                                         in_if,
                                         out_if,
                                         xover=xover)
        hof.set_mac(self.of_gen_key, ts, prev_hof)
        return PCBMarking.from_values(in_info["remote_ia"],
                                      in_info["remote_if"], in_info["mtu"],
                                      out_info["remote_ia"],
                                      out_info["remote_if"], hof)

    def _terminate_pcb(self, pcb):
        """
        Copies a PCB, terminates it and adds the segment ID.

        Terminating a PCB means adding a opaque field with the egress IF set
        to 0, i.e., there is no AS to forward a packet containing this path
        segment to.
        """
        pcb = pcb.copy()
        asm = self._create_asm(pcb.p.ifID, 0, pcb.get_timestamp(),
                               pcb.last_hof())
        if not asm:
            return None
        pcb.add_asm(asm)
        return pcb

    def handle_ifid_packet(self, pld, meta):
        """
        Record an IFID payload for interface ``pld.p.relayIF``.

        Stores the remote interface ID (``pld.p.origIF``) on the matching
        border router and updates the interface state. If the interface was
        not active before and this BS holds the ZK lock, all border routers
        are told that the interface is up.

        :param pld: The IFIDPayload.
        :type pld: IFIDPayload
        :raises SCIONKeyError: if the relay interface is unknown.
        """
        ifid = pld.p.relayIF
        with self.ifid_state_lock:
            if ifid not in self.ifid_state:
                raise SCIONKeyError("Invalid IF %d in IFIDPayload" % ifid)
            self.ifid2br[ifid].interface.to_if_id = pld.p.origIF
            prev_state = self.ifid_state[ifid].update()
            if prev_state == InterfaceState.INACTIVE:
                logging.info("IF %d activated", ifid)
            elif prev_state in (InterfaceState.TIMED_OUT,
                                InterfaceState.REVOKED):
                logging.info("IF %d came back up.", ifid)
            if prev_state == InterfaceState.ACTIVE:
                return
            if not self.zk.have_lock():
                return
            # Inform BRs about the interface coming up.
            up_info = IFStateInfo.from_values(ifid, True,
                                              self._get_ht_proof(ifid))
            up_pld = IFStatePayload.from_values([up_info])
            for router in self.topology.get_all_border_routers():
                br_meta = UDPMetadata.from_values(host=router.addr,
                                                  port=router.port)
                self.send_meta(up_pld.copy(), br_meta,
                               (router.addr, router.port))

    def run(self):
        """
        Start the beacon server.

        Launches the worker, interface-timeout and next-hash-tree loops as
        daemon threads (each wrapped in ``thread_safety_net``), then hands
        over to the parent class's ``run()``.
        """
        background = (
            (self.worker, "BS.worker"),
            # https://github.com/netsec-ethz/scion/issues/308:
            (self._handle_if_timeouts, "BS._handle_if_timeouts"),
            (self._create_next_tree, "BS._create_next_tree"),
        )
        for target, thread_name in background:
            threading.Thread(target=thread_safety_net,
                             args=(target, ),
                             name=thread_name,
                             daemon=True).start()
        super().run()

    def _create_next_tree(self):
        last_ttl_window = 0
        while self.run_flag.is_set():
            start = time.time()
            cur_ttl_window = ConnectedHashTree.get_ttl_window()
            time_to_sleep = (ConnectedHashTree.get_time_till_next_ttl() -
                             HASHTREE_UPDATE_WINDOW)
            if cur_ttl_window == last_ttl_window:
                time_to_sleep += HASHTREE_TTL
            if time_to_sleep > 0:
                sleep_interval(start, time_to_sleep, "BS._create_next_tree",
                               self._quiet_startup())

            # at this point, there should be <= HASHTREE_UPDATE_WINDOW
            # seconds left in current ttl
            logging.info("Started computing hashtree for next ttl")
            last_ttl_window = ConnectedHashTree.get_ttl_window()

            ifs = list(self.ifid2br.keys())
            tree = ConnectedHashTree.get_next_tree(self.addr.isd_as, ifs,
                                                   self.hashtree_gen_key)
            with self._hash_tree_lock:
                self._next_tree = tree

    def _maintain_hash_tree(self):
        """
        Maintain the hashtree. Update the the windows in the connected tree
        """
        with self._hash_tree_lock:
            if self._next_tree is not None:
                self._hash_tree.update(self._next_tree)
                self._next_tree = None
            else:
                logging.critical("Did not create hashtree in time; dying")
                kill_self()
        logging.info("New Hash Tree TTL beginning")

    def worker(self):
        """
        Periodic worker loop, run while ``self.run_flag`` is set.

        Each cycle processes the queue of incoming PCBs, retries unverified
        beacons, processes the shared PCB and revocation caches, applies the
        cached revocations, rolls the hash tree over when the TTL window has
        changed and tries to take the ZK lock. Only when the lock is held does
        the cycle go on to expire cache entries, propagate PCBs and register
        segments, each on its own configured interval. A lost ZK connection
        aborts the current cycle.
        """
        cycle_len = 1.0
        prev_propagation = 0
        prev_registration = 0
        prev_window = ConnectedHashTree.get_ttl_window()
        was_master = False
        cycle_start = time.time()
        while self.run_flag.is_set():
            sleep_interval(cycle_start, cycle_len, "BS.worker cycle",
                           self._quiet_startup())
            cycle_start = time.time()
            try:
                self.process_pcb_queue()
                self.handle_unverified_beacons()
                self.zk.wait_connected()
                self.pcb_cache.process()
                self.revobjs_cache.process()
                self.handle_rev_objs()

                window = ConnectedHashTree.get_ttl_window()
                if window != prev_window:
                    self._maintain_hash_tree()
                    prev_window = window

                holds_lock = self.zk.get_lock(lock_timeout=0, conn_timeout=0)
                if not holds_lock:
                    was_master = False
                    continue

                if not was_master:
                    # Transition from non-master to master.
                    self._became_master()
                    was_master = True
                self.pcb_cache.expire(self.config.propagation_time * 10)
                self.revobjs_cache.expire(self.ZK_REV_OBJ_MAX_AGE)
            except ZkNoConnection:
                continue
            now = time.time()
            if now - prev_propagation >= self.config.propagation_time:
                self.handle_pcbs_propagation()
                prev_propagation = now
            registration_due = (
                self.config.registers_paths
                and now - prev_registration >= self.config.registration_time)
            if registration_due:
                try:
                    self.register_segments()
                except SCIONKeyError as err:
                    logging.error("Register_segments: %s", err)
                # The timestamp advances even when registration failed.
                prev_registration = now

    def _became_master(self):
        """
        Called when a BS becomes the new master. Resets some state that will be
        rebuilt over time.
        """
        # Reset all timed-out and revoked interfaces to inactive.
        with self.ifid_state_lock:
            for (_, ifstate) in self.ifid_state.items():
                if not ifstate.is_active():
                    ifstate.reset()

    def _try_to_verify_beacon(self, pcb, quiet=False):
        """
        Verify a beacon if the required trust material is available.

        When ``_check_trc`` fails for the beacon's last AS marking, the beacon
        is queued in ``unverified_beacons`` for a later retry. Otherwise it is
        verified and, if valid, passed to ``_handle_verified_beacon``; an
        invalid beacon is logged and dropped.

        :param pcb: path segment to verify.
        :type pcb: PathSegment
        :param bool quiet: suppress the "missing" warning when queueing.
        """
        assert isinstance(pcb, PathSegment)
        last_asm = pcb.asm(-1)
        if not self._check_trc(last_asm.isd_as(), last_asm.p.trcVer):
            if not quiet:
                logging.warning("Certificate(s) or TRC missing for pcb: %s",
                                pcb.short_desc())
            self.unverified_beacons.append(pcb)
            return
        if self._verify_beacon(pcb):
            self._handle_verified_beacon(pcb)
        else:
            logging.warning("Invalid beacon. %s", pcb)

    @abstractmethod
    def _check_trc(self, isd_as, trc_ver):
        """
        Return True or False whether the necessary Certificate and TRC files are
        found.

        :param ISD_AS isd_is: ISD-AS identifier.
        :param int trc_ver: TRC file version.
        """
        raise NotImplementedError

    def _get_my_trc(self):
        return self.trust_store.get_trc(self.addr.isd_as[0])

    def _get_my_cert(self):
        return self.trust_store.get_cert(self.addr.isd_as)

    def _get_trc(self, isd_as, trc_ver):
        """
        Get TRC from local storage or memory.

        :param ISD_AS isd_as: ISD-AS identifier.
        :param int trc_ver: TRC file version.
        """
        trc = self.trust_store.get_trc(isd_as[0], trc_ver)
        if not trc:
            # Requesting TRC file from cert server
            trc_tuple = isd_as[0], trc_ver
            now = int(time.time())
            if (trc_tuple not in self.trc_requests or
                (now - self.trc_requests[trc_tuple] > self.REQUESTS_TIMEOUT)):
                trc_req = TRCRequest.from_values(isd_as, trc_ver)
                logging.info("Requesting %sv%s TRC", isd_as[0], trc_ver)
                try:
                    addr, port = self.dns_query_topo(CERTIFICATE_SERVICE)[0]
                except SCIONServiceLookupError as e:
                    logging.warning("Sending TRC request failed: %s", e)
                    return None
                meta = UDPMetadata.from_values(host=addr, port=port)
                self.send_meta(trc_req, meta)
                self.trc_requests[trc_tuple] = now
                return None
        return trc

    def _verify_beacon(self, pcb):
        """
        Check the signature of the beacon's last AS marking.

        The signature is verified with ``verify_sig_chain_trc``, using the
        certificate chain carried in the marking and the TRC from the trust
        store.

        :param pcb: path segment to verify.
        :type pcb: PathSegment
        :returns: the result of ``verify_sig_chain_trc``.
        """
        assert isinstance(pcb, PathSegment)
        last_asm = pcb.asm(-1)
        signer_ia = last_asm.isd_as()
        trc_version = last_asm.p.trcVer
        trc = self.trust_store.get_trc(signer_ia[0], trc_version)
        return verify_sig_chain_trc(pcb.sig_pack(), last_asm.p.sig,
                                    str(signer_ia), last_asm.chain(), trc,
                                    last_asm.p.trcVer)

    @abstractmethod
    def _handle_verified_beacon(self, pcb):
        """
        Once a beacon has been verified, place it into the right containers.

        :param pcb: verified path segment.
        :type pcb: PathSegment
        """
        raise NotImplementedError

    @abstractmethod
    def process_cert_chain_rep(self, cert_chain_rep, meta):
        """
        Handle a certificate chain reply.

        Abstract: subclasses must override this; the base implementation
        always raises.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError

    def process_trc_rep(self, rep, meta):
        """
        Handle a TRC reply: add the TRC to the trust store and clear the
        matching entry, if any, from the outstanding TRC requests.

        :param rep: TRC reply.
        :type rep: TRCReply
        """
        isd_ver = rep.trc.get_isd_ver()
        logging.info("TRC reply received for %s", isd_ver)
        self.trust_store.add_trc(rep.trc)
        # The request, if one was recorded, is now answered.
        self.trc_requests.pop(isd_ver, None)

    def handle_unverified_beacons(self):
        """
        Retry verification of the beacons waiting in ``unverified_beacons``.

        Only the beacons queued at the time of the call are retried; beacons
        that get re-queued during the pass are left for the next call.
        """
        remaining = len(self.unverified_beacons)
        while remaining:
            remaining -= 1
            beacon = self.unverified_beacons.popleft()
            self._try_to_verify_beacon(beacon, quiet=True)

    def process_rev_objects(self, rev_infos):
        """
        Parse raw revocation infos and add them to the local revocation cache.

        Entries that fail to parse are logged and skipped. The whole batch is
        handled while holding ``local_rev_cache_lock``.

        :param rev_infos: iterable of raw (packed) revocation infos.
        """
        with self.local_rev_cache_lock:
            for blob in rev_infos:
                try:
                    parsed = RevocationInfo.from_raw(blob)
                except SCIONParseError as err:
                    logging.error(
                        "Error processing revocation info from ZK: %s", err)
                else:
                    self.local_rev_cache[parsed] = parsed.copy()

    def _issue_revocation(self, if_id):
        """
        Revoke an interface: notify every border router, then process the
        revocation locally and pass it on to the local path server.

        Nothing happens unless this instance holds the ZK lock (i.e. it is
        the master BS).

        :param if_id: The interface that needs to be revoked.
        :type if_id: int
        """
        # Only the master BS issues revocations.
        if not self.zk.have_lock():
            return
        rev_info = self._get_ht_proof(if_id)
        # NOTE(review): logged at error level although issuing a revocation
        # looks like a regular action of the master — confirm intended.
        logging.error("Issuing revocation for IF %d.", if_id)
        # Tell all BRs that the interface is no longer active.
        state_info = IFStateInfo.from_values(if_id, False, rev_info)
        payload = IFStatePayload.from_values([state_info])
        for router in self.topology.get_all_border_routers():
            router_meta = UDPMetadata.from_values(
                host=router.addr, port=router.port)
            # Each router gets its own copy of the payload.
            self.send_meta(
                payload.copy(), router_meta, (router.addr, router.port))
        self._process_revocation(rev_info)
        self._send_rev_to_local_ps(rev_info)

    def _send_rev_to_local_ps(self, rev_info):
        """
        Forward a revocation to a local path server.

        The revocation is only sent when this instance holds the ZK lock,
        the topology lists path servers, and the path service lookup
        succeeds; otherwise the method returns without sending.

        :param rev_info: The RevocationInfo object
        :type rev_info: RevocationInfo
        """
        if not (self.zk.have_lock() and self.topology.path_servers):
            return
        try:
            ps_addr, ps_port = self.dns_query_topo(PATH_SERVICE)[0]
        except SCIONServiceLookupError:
            # No local path server could be resolved, nothing to send.
            return
        logging.info("Sending revocation to local PS.")
        ps_meta = UDPMetadata.from_values(host=ps_addr, port=ps_port)
        self.send_meta(rev_info.copy(), ps_meta)

    def _handle_scmp_revocation(self, pld, meta):
        """
        Process a revocation that arrived inside an SCMP payload.

        The raw revocation info is parsed from ``pld.info.rev_info`` and
        handed to ``_process_revocation``.

        :param pld: SCMP payload carrying the raw revocation info.
        :param meta: Metadata of the message; not used by this method.
        """
        # NOTE(review): unlike _handle_revocation, no _validate_revocation
        # call is made on this path — confirm this is intentional.
        revocation = RevocationInfo.from_raw(pld.info.rev_info)
        description = revocation.short_desc()
        logging.info("Received revocation via SCMP:\n%s", description)
        self._process_revocation(revocation)

    def _handle_revocation(self, rev_info, meta):
        """
        Process a revocation received as a regular (TCP/UDP) message.

        The revocation is only processed if ``_validate_revocation``
        accepts it.

        :param rev_info: The RevocationInfo object
        :type rev_info: RevocationInfo
        :param meta: Metadata of the message; not used by this method.
        """
        description = rev_info.short_desc()
        logging.info("Received revocation via TCP/UDP:\n%s", description)
        if self._validate_revocation(rev_info):
            self._process_revocation(rev_info)

    def handle_rev_objs(self):
        """
        Re-apply every revocation in the local revocation cache by passing
        each one to ``_remove_revoked_pcbs``.

        The cache lock is held for the whole pass.
        """
        drop_revoked = self._remove_revoked_pcbs
        with self.local_rev_cache_lock:
            for cached_rev in self.local_rev_cache.values():
                drop_revoked(cached_rev)

    def _process_revocation(self, rev_info):
        """
        Records a revocation and removes the PCBs affected by it.

        The revocation is added to the local revocation cache, stored in the
        shared ZK revocation cache (best effort: a missing ZK connection is
        logged and otherwise ignored) and finally passed to
        ``_remove_revoked_pcbs``. Revocations for interface ID 0 are rejected
        with an error log. This method does not itself send the revocation
        to the local PS.

        :param rev_info: The RevocationInfo object
        :type rev_info: RevocationInfo
        """
        assert isinstance(rev_info, RevocationInfo)
        if_id = rev_info.p.ifID
        if not if_id:
            logging.error("Trying to revoke IF with ID 0.")
            return

        with self.local_rev_cache_lock:
            self.local_rev_cache[rev_info] = rev_info.copy()

        logging.info("Storing revocation in ZK.")
        rev_token = rev_info.copy().pack()
        # Entry name combines the token's hash with the current time.
        entry_name = "%s:%s" % (hash(rev_token), time.time())
        try:
            self.revobjs_cache.store(entry_name, rev_token)
        except ZkNoConnection as exc:
            # Pass the exception as a logging argument so the message is
            # only formatted when the record is actually emitted.
            logging.error("Unable to store revocation in shared cache "
                          "(no ZK connection): %s", exc)
        self._remove_revoked_pcbs(rev_info)

    @abstractmethod
    def _remove_revoked_pcbs(self, rev_info):
        """
        Removes the PCBs containing the revoked interface.

        Abstract hook: this base implementation only raises
        NotImplementedError. It is invoked with a single revocation by
        ``handle_rev_objs`` (once per cached revocation) and by
        ``_process_revocation``.

        :param rev_info: The RevocationInfo object.
        :type rev_info: RevocationInfo
        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def _pcb_list_to_remove(self, candidates, rev_info):
        """
        Calculates the list of PCBs to remove.
        Called by _remove_revoked_pcbs.

        A candidate is selected when either

        - the revocation was issued for the local AS, its proof verifies
          against the local hash tree root, and the interface ID of the
          candidate's PCB equals the revoked interface ID; or
        - ``_verify_revocation_for_asm`` accepts the revocation for at least
          one AS marking of the candidate's PCB.

        If the revocation's epoch does not verify, nothing is selected.

        :param candidates: Candidate PCBs.
        :type candidates: List
        :param rev_info: The RevocationInfo object.
        :type rev_info: RevocationInfo
        :returns: IDs of the candidates to remove. Each ID is listed at most
            once, in the order the candidates were first seen.
        :rtype: list
        """
        if not candidates:
            return []
        # The epoch check does not depend on the candidate: if it fails, no
        # candidate can be selected.
        if not ConnectedHashTree.verify_epoch(rev_info.p.epoch):
            return []

        revoked_if = rev_info.p.ifID
        # Whether this is a verifiable revocation of one of the local AS's
        # own interfaces. Candidate-independent, so it is computed once
        # instead of re-verifying the proof for every candidate.
        local_revocation = (
            self.addr.isd_as == rev_info.isd_as()
            and ConnectedHashTree.verify(rev_info, self._get_ht_root()))

        to_remove = []
        processed = set()
        for cand in candidates:
            if cand.id in processed:
                continue
            processed.add(cand.id)

            # The PCB was received on the revoked local interface.
            revoked = local_revocation and cand.pcb.p.ifID == revoked_if
            # Otherwise one matching AS marking is enough; stop at the first
            # match so the id is not listed more than once.
            if not revoked:
                revoked = any(
                    self._verify_revocation_for_asm(rev_info, asm, False)
                    for asm in cand.pcb.iter_asms())
            if revoked:
                to_remove.append(cand.id)

        return to_remove

    def _handle_if_timeouts(self):
        """
        Worker loop that revokes interfaces which are expired or already
        revoked.

        Runs while ``self.run_flag`` is set. On every pass each interface
        state is checked under ``self.ifid_state_lock``; an interface that is
        expired or revoked gets a revocation issued at most once per hash
        tree epoch. Passes are paced with ``sleep_interval`` using
        ``self.IF_TIMEOUT_INTERVAL``.
        """
        # Epoch in which a revocation was last issued, per interface ID.
        last_revoked_epoch = defaultdict(int)
        while self.run_flag.is_set():
            pass_started = time.time()
            with self.ifid_state_lock:
                for if_id, if_state in self.ifid_state.items():
                    epoch_now = ConnectedHashTree.get_current_epoch()
                    # Skip interfaces that are neither timed out nor revoked.
                    if not (if_state.is_expired() or if_state.is_revoked()):
                        continue
                    # Already handled in the current epoch.
                    if last_revoked_epoch[if_id] == epoch_now:
                        continue
                    last_revoked_epoch[if_id] = epoch_now
                    if not if_state.is_revoked():
                        logging.info("IF %d appears to be down.", if_id)
                    self._issue_revocation(if_id)
                    if_state.revoke_if_expired()
            sleep_interval(pass_started, self.IF_TIMEOUT_INTERVAL,
                           "Handle IF timeouts")

    def _handle_ifstate_request(self, req, meta):
        """
        Answer an interface state request.

        Only the instance holding the ZK lock replies. The request may ask
        for a single interface or for all of them
        (``IFStateRequest.ALL_INTERFACES``); a request for an unknown
        interface is logged and dropped. Inactive interfaces are left out of
        the reply.

        :param req: The interface state request.
        :type req: IFStateRequest
        :param meta: Metadata of the request; the reply is sent back to it.
        """
        # Only master replies to ifstate requests.
        if not self.zk.have_lock():
            return
        assert isinstance(req, IFStateRequest)
        logging.debug("Received ifstate req:\n%s", req)
        with self.ifid_state_lock:
            if req.p.ifID == IFStateRequest.ALL_INTERFACES:
                selected = self.ifid_state.items()
            elif req.p.ifID in self.ifid_state:
                selected = [(req.p.ifID, self.ifid_state[req.p.ifID])]
            else:
                logging.error(
                    "Received ifstate request from %s for unknown "
                    "interface %s.", meta.get_addr(), req.p.ifID)
                return

            # Don't include inactive interfaces in response.
            infos = [
                IFStateInfo.from_values(ifid, state.is_active(),
                                        self._get_ht_proof(ifid))
                for ifid, state in selected
                if not state.is_inactive()
            ]
        # NOTE(review): with no infos during quiet startup an empty payload
        # is still sent — confirm this is intended.
        if not infos and not self._quiet_startup():
            logging.warning("No IF state info to put in response.")
            return
        reply = IFStatePayload.from_values(infos)
        self.send_meta(reply, meta, (meta.host, meta.port))