Ejemplo n.º 1
0
 def __init__(self, server_id, conf_dir):
     """
     Set up the path server: down/core segment DBs, request bookkeeping,
     the control-payload dispatch table and Zookeeper membership with a
     shared path cache.

     :param str server_id: server identifier.
     :param str conf_dir: configuration directory.
     """
     super().__init__(server_id, conf_dir)
     seg_limit = self.MAX_SEG_NO
     self.down_segments = PathSegmentDB(max_res_no=seg_limit)
     self.core_segments = PathSegmentDB(max_res_no=seg_limit)
     # Pending path requests, grouped per key.
     self.pending_req = defaultdict(list)
     # Targets waiting for a path; used when the local/core PS has no
     # up/down path yet.
     self.waiting_targets = defaultdict(list)
     self.revocations = ExpiringDict(1000, 300)
     self.iftoken2seg = defaultdict(set)
     path_handlers = {
         PMT.REQUEST: self.path_resolution,
         PMT.REPLY: self.handle_path_segment_record,
         PMT.REG: self.handle_path_segment_record,
         PMT.REVOCATION: self._handle_revocation,
         PMT.SYNC: self.handle_path_segment_record,
     }
     self.CTRL_PLD_CLASS_MAP = {PayloadClass.PATH: path_handlers}
     self._segs_to_zk = deque()
     # Zookeeper party name: NUL-separated server id, UDP port and host.
     # Add more IPs here if we support dual-stack.
     name_parts = [self.id, str(SCION_UDP_PORT), str(self.addr.host)]
     name_addrs = "\0".join(name_parts)
     zk = Zookeeper(self.topology.isd_as, PATH_SERVICE, name_addrs,
                    self.topology.zookeepers)
     self.zk = zk
     zk.retry("Joining party", zk.party_setup)
     self.path_cache = ZkSharedCache(zk, self.ZK_PATH_CACHE_PATH,
                                     self._cached_entries_handler)
Ejemplo n.º 2
0
 def __init__(self, server_id, conf_dir, prom_export=None):
     """
     Set up server state: send queue, signing key, segment DB, steady-path
     maps, per-interface link info, payload dispatch table and Zookeeper
     membership.

     :param str server_id: server identifier.
     :param str conf_dir: configuration directory.
     :param str prom_export: prometheus export address.
     """
     super().__init__(server_id, conf_dir, prom_export=prom_export)
     self.sendq = Queue()
     self.signing_key = get_sig_key(self.conf_dir)
     self.segments = PathSegmentDB(max_res_no=1)
     # Steady paths as {ISD-AS: {steady path id: steady path}}; srcs holds
     # the incoming ones, dests the outgoing ones.
     self.srcs, self.dests = {}, {}
     # Keyed by interface ID: SibraState objects and link types.
     self.link_states, self.link_types = {}, {}
     self.lock = threading.Lock()
     path_handlers = {PMT.REG: self.handle_path_reg}
     sibra_handlers = {PayloadClass.SIBRA: self.handle_sibra_pkt}
     self.CTRL_PLD_CLASS_MAP = {
         PayloadClass.PATH: path_handlers,
         PayloadClass.SIBRA: sibra_handlers,
     }
     self._find_links()
     endpoints = [(self.addr.host, self._port)]
     zkid = ZkID.from_values(self.addr.isd_as, self.id, endpoints).pack()
     self.zk = Zookeeper(self.addr.isd_as, self.SERVICE_TYPE, zkid,
                         self.topology.zookeepers)
     self.zk.retry("Joining party", self.zk.party_setup)
Ejemplo n.º 3
0
 def __init__(self, server_id, conf_dir):
     """
     Set up server state: send queue, base64-decoded signing key read from
     the config directory, segment DB, steady-path maps, per-interface link
     info, payload dispatch table and Zookeeper membership.

     :param str server_id: server identifier.
     :param str conf_dir: configuration directory.
     """
     super().__init__(server_id, conf_dir)
     self.sendq = Queue()
     key_path = get_sig_key_file_path(self.conf_dir)
     encoded_key = read_file(key_path)
     self.signing_key = base64.b64decode(encoded_key)
     self.segments = PathSegmentDB(max_res_no=1)
     # Steady paths as {ISD-AS: {steady path id: steady path}}; srcs holds
     # the incoming ones, dests the outgoing ones.
     self.srcs, self.dests = {}, {}
     # Keyed by interface ID: SibraState objects and link types.
     self.link_states, self.link_types = {}, {}
     self.lock = threading.Lock()
     self.CTRL_PLD_CLASS_MAP = {
         PayloadClass.PATH: {PMT.REG: self.handle_path_reg},
         PayloadClass.SIBRA: {SIBRAPayloadType.EMPTY: self.handle_sibra_pkt},
     }
     self._find_links()
     # Zookeeper party name: NUL-separated server id, UDP port and host.
     name_parts = (self.id, str(SCION_UDP_PORT), str(self.addr.host))
     name_addrs = "\0".join(name_parts)
     self.zk = Zookeeper(self.addr.isd_as, SIBRA_SERVICE, name_addrs,
                         self.topology.zookeepers)
     self.zk.retry("Joining party", self.zk.party_setup)
Ejemplo n.º 4
0
 def test_basic(self):
     """delete() looks the entry up by id and deletes the record found."""
     db = PathSegmentDB()
     store = create_mock(['delete'])
     store.return_value = "data1"
     db._db = store
     result = db.delete("data2")
     ntools.eq_(result, DBResult.ENTRY_DELETED)
     store.assert_called_once_with(id="data2")
     store.delete.assert_called_once_with("data1")
Ejemplo n.º 5
0
 def __init__(self, server_id, conf_dir, prom_export=None):
     """
     Set up the path server: labelled down/core segment DBs, request and
     revocation tracking, control/SCMP dispatch tables, Zookeeper
     membership with shared path and revocation caches, and the request
     logger.

     :param str server_id: server identifier.
     :param str conf_dir: configuration directory.
     :param str prom_export: prometheus export address.
     """
     super().__init__(server_id, conf_dir, prom_export=prom_export)

     def typed_labels(seg_type):
         # Copy of the server labels tagged with a segment type, or None
         # when the server has no labels.
         if not self._labels:
             return None
         return {**self._labels, "type": seg_type}

     self.down_segments = PathSegmentDB(max_res_no=self.MAX_SEG_NO,
                                        labels=typed_labels("down"))
     self.core_segments = PathSegmentDB(max_res_no=self.MAX_SEG_NO,
                                        labels=typed_labels("core"))
     # Pending path requests and their lock.
     self.pending_req = defaultdict(list)
     self.pen_req_lock = threading.Lock()
     self._request_logger = None
     # Used when l/cPS doesn't have up/dw-path.
     self.waiting_targets = defaultdict(list)
     self.revocations = RevCache(labels=self._labels)
     # A mapping from (hash tree root of AS, IFID) to segments, with a lock.
     self.htroot_if2seg = ExpiringDict(1000, HASHTREE_TTL)
     self.htroot_if2seglock = Lock()
     path_handlers = {
         PMT.REQUEST: self.path_resolution,
         PMT.REPLY: self.handle_path_segment_record,
         PMT.REG: self.handle_path_segment_record,
         PMT.REVOCATION: self._handle_revocation,
         PMT.SYNC: self.handle_path_segment_record,
     }
     cert_handlers = {
         CertMgmtType.CERT_CHAIN_REQ: self.process_cert_chain_request,
         CertMgmtType.CERT_CHAIN_REPLY: self.process_cert_chain_reply,
         CertMgmtType.TRC_REPLY: self.process_trc_reply,
         CertMgmtType.TRC_REQ: self.process_trc_request,
     }
     self.CTRL_PLD_CLASS_MAP = {
         PayloadClass.PATH: path_handlers,
         PayloadClass.CERT: cert_handlers,
     }
     scmp_path_handlers = {
         SCMPPathClass.REVOKED_IF: self._handle_scmp_revocation,
     }
     self.SCMP_PLD_CLASS_MAP = {SCMPClass.PATH: scmp_path_handlers}
     self._segs_to_zk = ExpiringDict(1000, self.SEGS_TO_ZK_TTL)
     self._revs_to_zk = ExpiringDict(1000, HASHTREE_EPOCH_TIME)
     endpoints = [(self.addr.host, self._port)]
     self._zkid = ZkID.from_values(self.addr.isd_as, self.id, endpoints)
     zk_name = self._zkid.copy().pack()
     self.zk = Zookeeper(self.topology.isd_as, PATH_SERVICE, zk_name,
                         self.topology.zookeepers)
     self.zk.retry("Joining party", self.zk.party_setup)
     self.path_cache = ZkSharedCache(self.zk, self.ZK_PATH_CACHE_PATH,
                                     self._handle_paths_from_zk)
     self.rev_cache = ZkSharedCache(self.zk, self.ZK_REV_CACHE_PATH,
                                    self._rev_entries_handler)
     self._init_request_logger()
Ejemplo n.º 6
0
 def test(self):
     """delete_all() deletes each id and returns how many were removed."""
     inst = PathSegmentDB()
     ids = (0, 1, 2)
     fake_delete = create_mock()
     fake_delete.side_effect = (DBResult.ENTRY_DELETED, DBResult.NONE,
                                DBResult.ENTRY_DELETED)
     inst.delete = fake_delete
     # Call
     removed = inst.delete_all(ids)
     ntools.eq_(removed, 2)
     # Tests
     expected_calls = [call(i) for i in ids]
     assert_these_calls(fake_delete, expected_calls)
Ejemplo n.º 7
0
 def __init__(self, server_id, conf_dir):
     """
     Initialise a local (non-core) path server and its up-segment DB.

     :param str server_id: server identifier.
     :param str conf_dir: configuration directory.
     """
     super().__init__(server_id, conf_dir)
     # This server type must not run in a core AS.
     in_core_as = self.topology.is_core_as
     assert not in_core_as, "This shouldn't be a core PS!"
     # Up-segments to the core; max_res_no is taken from MAX_SEG_NO.
     seg_limit = self.MAX_SEG_NO
     self.up_segments = PathSegmentDB(max_res_no=seg_limit)
Ejemplo n.º 8
0
 def test_outdated(self, db_rec):
     """update() returns DBResult.NONE for an outdated PCB."""
     inst = PathSegmentDB()
     incoming = self._mk_pcb(-1)
     stored = create_mock_full({"pcb": self._mk_pcb(0)})
     inst._db = create_mock_full(return_value={0: {'record': stored}})
     db_rec.return_value = create_mock_full({'id': "idstr"})
     # Call
     outcome = inst.update(incoming)
     ntools.eq_(outcome, DBResult.NONE)
     # Tests: both expiration times were consulted exactly once.
     incoming.get_expiration_time.assert_called_once_with()
     stored.pcb.get_expiration_time.assert_called_once_with()
Ejemplo n.º 9
0
 def test_add(self, db_rec):
     """update() inserts a PCB for which no record exists yet."""
     inst = PathSegmentDB()
     inst._db = create_mock_full({'insert()': None}, return_value=[])
     pcb = self._mk_pcb()
     rec_id = "id str"
     new_rec = create_mock_full({'id': rec_id})
     db_rec.return_value = new_rec
     # Call
     outcome = inst.update(pcb)
     ntools.eq_(outcome, DBResult.ENTRY_ADDED)
     # Tests
     db_rec.assert_called_once_with(pcb)
     inst._db.assert_called_once_with(id=rec_id, sibra=True)
     inst._db.insert.assert_called_once_with(new_rec, rec_id, 1, 2, 3, 4,
                                             True)
Ejemplo n.º 10
0
 def test(self, time):
     """_exp_call_records() deletes records with exp_time before now."""
     inst = PathSegmentDB()
     inst._db = create_mock(['delete'])

     def make_entry(exp):
         entry = create_mock(['exp_time', 'pcb'])
         entry.exp_time = exp
         entry.pcb = create_mock(["short_desc"])
         return {'record': entry}

     recs = [make_entry(exp) for exp in range(5)]
     time.return_value = 2
     # Call
     live = inst._exp_call_records(recs)
     ntools.eq_(live, recs[2:])
     # Tests
     inst._db.delete.assert_called_once_with(recs[:2])
Ejemplo n.º 11
0
 def test_with_segment_ttl(self, db_rec, time):
     """With a segment TTL, update() resets the stored exp_time to now+TTL."""
     ttl = 300
     inst = PathSegmentDB(ttl)
     stored = create_mock(['pcb', 'id', 'exp_time'])
     stored.pcb = self._mk_pcb(0)
     stored.exp_time = 10
     inst._db = create_mock_full(return_value={0: {'record': stored}})
     incoming = self._mk_pcb(1)
     db_rec.return_value = create_mock(['id'])
     now = 1
     time.return_value = now
     # Call
     inst.update(incoming)
     # Tests
     db_rec.assert_called_once_with(incoming, ttl + now)
     ntools.eq_(stored.exp_time, 301)
Ejemplo n.º 12
0
 def test_update(self, db_rec):
     """update() replaces the stored PCB and reports ENTRY_UPDATED."""
     inst = PathSegmentDB()
     pcb = self._mk_pcb(1)
     stored_attrs = {"pcb": self._mk_pcb(0), "id": "cur rec", "exp_time": 44}
     cur_rec = create_mock_full(stored_attrs)
     inst._db = create_mock_full(return_value={0: {'record': cur_rec}})
     new_attrs = {'id': "record", 'exp_time': 32}
     db_rec.return_value = create_mock_full(new_attrs)
     # Call
     status = inst.update(pcb)
     ntools.eq_(status, DBResult.ENTRY_UPDATED)
     # Tests
     ntools.eq_(cur_rec.pcb, pcb)
Ejemplo n.º 13
0
 def __init__(self, server_id, conf_dir):
     """
     Set up the path server: down/core segment DBs, request tracking,
     revocation and PCB caches, control/SCMP dispatch tables, and Zookeeper
     membership with shared path and revocation caches.

     :param str server_id: server identifier.
     :param str conf_dir: configuration directory.
     """
     super().__init__(server_id, conf_dir)
     seg_limit = self.MAX_SEG_NO
     self.down_segments = PathSegmentDB(max_res_no=seg_limit)
     self.core_segments = PathSegmentDB(max_res_no=seg_limit)
     # Pending path requests, grouped per key.
     self.pending_req = defaultdict(list)
     # Used when l/cPS doesn't have up/dw-path.
     self.waiting_targets = defaultdict(list)
     epoch = HASHTREE_EPOCH_TIME
     self.revocations = ExpiringDict(1000, epoch)
     # PCBs that include revocations, plus an accompanying lock.
     self.pcb_cache = ExpiringDict(100, epoch)
     self.pcb_cache_lock = Lock()
     # A mapping from (hash tree root of AS, IFID) to segments, with a lock.
     self.htroot_if2seg = ExpiringDict(1000, HASHTREE_TTL)
     self.htroot_if2seglock = Lock()
     path_handlers = {
         PMT.REQUEST: self.path_resolution,
         PMT.REPLY: self.handle_path_segment_record,
         PMT.REG: self.handle_path_segment_record,
         PMT.REVOCATION: self._handle_revocation,
         PMT.SYNC: self.handle_path_segment_record,
     }
     self.CTRL_PLD_CLASS_MAP = {PayloadClass.PATH: path_handlers}
     scmp_path_handlers = {
         SCMPPathClass.REVOKED_IF: self._handle_scmp_revocation,
     }
     self.SCMP_PLD_CLASS_MAP = {SCMPClass.PATH: scmp_path_handlers}
     self._segs_to_zk, self._revs_to_zk = deque(), deque()
     endpoints = [(self.addr.host, self._port)]
     self._zkid = ZkID.from_values(self.addr.isd_as, self.id, endpoints)
     zk_name = self._zkid.copy().pack()
     self.zk = Zookeeper(self.topology.isd_as, PATH_SERVICE, zk_name,
                         self.topology.zookeepers)
     self.zk.retry("Joining party", self.zk.party_setup)
     self.path_cache = ZkSharedCache(self.zk, self.ZK_PATH_CACHE_PATH,
                                     self._cached_entries_handler)
     self.rev_cache = ZkSharedCache(self.zk, self.ZK_REV_CACHE_PATH,
                                    self._rev_entries_handler)
Ejemplo n.º 14
0
 def __init__(self, server_id, conf_dir, prom_export=None):
     """
     Initialise a local (non-core) path server and its up-segment DB.

     :param str server_id: server identifier.
     :param str conf_dir: configuration directory.
     :param str prom_export: prometheus export address.
     """
     # Pass prom_export by keyword, as the parent path-server constructor
     # itself does, so this call does not depend on positional order.
     super().__init__(server_id, conf_dir, prom_export=prom_export)
     # Sanity check that we should indeed be a local path server.
     assert not self.topology.is_core_as, "This shouldn't be a core PS!"
     # Database of up-segments to the core; when the server has labels,
     # the DB gets a copy of them tagged with type "up".
     up_labels = {**self._labels, "type": "up"} if self._labels else None
     self.up_segments = PathSegmentDB(max_res_no=self.MAX_SEG_NO,
                                      labels=up_labels)
Ejemplo n.º 15
0
    def __init__(self, conf_dir, addr, api_addr, run_local_api=False,
                 port=None, prom_export=None):
        """
        Initialize an instance of the class SCIONDaemon.

        Creates TTL-limited up/down/core segment DBs (labelled per type when
        the daemon has labels), a peer revocation cache, requested-path
        tracking, the control/SCMP dispatch tables and, if
        ``run_local_api`` is set, the local API socket.
        """
        super().__init__("sciond", conf_dir, prom_export=prom_export,
                         public=[(addr, port)])

        def typed_labels(seg_type):
            # Copy of the daemon labels tagged with a segment type, or None
            # when the daemon has no labels.
            if not self._labels:
                return None
            return {**self._labels, "type": seg_type}

        ttl = self.SEGMENT_TTL
        self.up_segments = PathSegmentDB(segment_ttl=ttl,
                                         labels=typed_labels("up"))
        self.down_segments = PathSegmentDB(segment_ttl=ttl,
                                           labels=typed_labels("down"))
        self.core_segments = PathSegmentDB(segment_ttl=ttl,
                                           labels=typed_labels("core"))
        self.peer_revs = RevCache()
        # Keep track of requested paths.
        self.requested_paths = ExpiringDict(self.MAX_REQS, self.PATH_REQ_TOUT)
        self.req_path_lock = threading.Lock()
        self._api_sock = None
        self.daemon_thread = None
        os.makedirs(SCIOND_API_SOCKDIR, exist_ok=True)
        if api_addr:
            self.api_addr = api_addr
        else:
            sock_name = "%s.sock" % self.addr.isd_as
            self.api_addr = os.path.join(SCIOND_API_SOCKDIR, sock_name)

        path_handlers = {
            PMT.REPLY: self.handle_path_reply,
            PMT.REVOCATION: self.handle_revocation,
        }
        cert_handlers = {
            CertMgmtType.CERT_CHAIN_REQ: self.process_cert_chain_request,
            CertMgmtType.CERT_CHAIN_REPLY: self.process_cert_chain_reply,
            CertMgmtType.TRC_REPLY: self.process_trc_reply,
            CertMgmtType.TRC_REQ: self.process_trc_request,
        }
        self.CTRL_PLD_CLASS_MAP = {
            PayloadClass.PATH: path_handlers,
            PayloadClass.CERT: cert_handlers,
        }
        scmp_path_handlers = {
            SCMPPathClass.REVOKED_IF: self.handle_scmp_revocation,
        }
        self.SCMP_PLD_CLASS_MAP = {SCMPClass.PATH: scmp_path_handlers}

        if not run_local_api:
            return
        self._api_sock = ReliableSocket(bind_unix=(self.api_addr, "sciond"))
        self._socks.add(self._api_sock, self.handle_accept)
Ejemplo n.º 16
0
    def __init__(self, conf_dir, addr, api_addr, run_local_api=False,
                 port=None):
        """
        Initialize an instance of the class SCIONDaemon.

        Creates TTL-limited up/down/core segment DBs, a peer revocation
        cache, the request handler, the control/SCMP dispatch tables and,
        if ``run_local_api`` is set, the local API socket.
        """
        super().__init__("sciond", conf_dir, host_addr=addr, port=port)

        # TODO replace by pathstore instance
        def new_seg_db():
            return PathSegmentDB(segment_ttl=self.SEGMENT_TTL,
                                 max_res_no=self.MAX_SEG_NO)

        self.up_segments = new_seg_db()
        self.down_segments = new_seg_db()
        self.core_segments = new_seg_db()
        self.peer_revs = RevCache()
        self.requests = RequestHandler.start(
            "SCIONDaemon Requests %s" % self.addr.isd_as,
            self._check_segments,
            self._fetch_segments,
            self._reply_segments,
            ttl=self.TIMEOUT,
            key_map=self._req_key_map,
        )
        self._api_sock = None
        self.daemon_thread = None
        os.makedirs(SCIOND_API_SOCKDIR, exist_ok=True)
        if api_addr:
            self.api_addr = api_addr
        else:
            sock_name = "%s.sock" % self.addr.isd_as
            self.api_addr = os.path.join(SCIOND_API_SOCKDIR, sock_name)

        path_handlers = {
            PMT.REPLY: self.handle_path_reply,
            PMT.REVOCATION: self.handle_revocation,
        }
        self.CTRL_PLD_CLASS_MAP = {PayloadClass.PATH: path_handlers}
        scmp_path_handlers = {
            SCMPPathClass.REVOKED_IF: self.handle_scmp_revocation,
        }
        self.SCMP_PLD_CLASS_MAP = {SCMPClass.PATH: scmp_path_handlers}

        if not run_local_api:
            return
        self._api_sock = ReliableSocket(bind=(self.api_addr, "sciond"))
        self._socks.add(self._api_sock, self.handle_accept)
Ejemplo n.º 17
0
    def __init__(self, conf_dir, addr, api_addr, run_local_api=False,
                 port=None, spki_cache_dir=GEN_CACHE_PATH, prom_export=None,
                 delete_sock=False):
        """
        Initialize an instance of the class SCIONDaemon.

        :param conf_dir: configuration directory.
        :param addr: public address, registered together with ``port``.
        :param api_addr: path of the local API socket; when falsy,
            ``get_default_sciond_path()`` is used instead.
        :param bool run_local_api: if true, bind the local API socket at
            ``self.api_addr`` and register it for accepting connections.
        :param port: public port, registered together with ``addr``.
        :param spki_cache_dir: forwarded to the base class as its
            ``spki_cache_dir``.
        :param str prom_export: prometheus export address.
        :param bool delete_sock: if true, remove any existing file at
            ``self.api_addr`` first; a missing file is not an error, other
            failures are logged and initialization continues.
        """
        super().__init__("sciond", conf_dir, spki_cache_dir=spki_cache_dir,
                         prom_export=prom_export, public=[(addr, port)])
        up_labels = {**self._labels, "type": "up"} if self._labels else None
        down_labels = ({**self._labels, "type": "down"}
                       if self._labels else None)
        core_labels = ({**self._labels, "type": "core"}
                       if self._labels else None)
        self.up_segments = PathSegmentDB(segment_ttl=self.SEGMENT_TTL,
                                         labels=up_labels)
        self.down_segments = PathSegmentDB(segment_ttl=self.SEGMENT_TTL,
                                           labels=down_labels)
        self.core_segments = PathSegmentDB(segment_ttl=self.SEGMENT_TTL,
                                           labels=core_labels)
        self.rev_cache = RevCache()
        # Keep track of requested paths.
        self.requested_paths = ExpiringDict(self.MAX_REQS, PATH_REQ_TOUT)
        self.req_path_lock = threading.Lock()
        self._api_sock = None
        self.daemon_thread = None
        os.makedirs(SCIOND_API_SOCKDIR, exist_ok=True)
        self.api_addr = api_addr or get_default_sciond_path()
        if delete_sock:
            try:
                os.remove(self.api_addr)
            except FileNotFoundError:
                # Nothing to remove.
                pass
            except OSError as e:
                # Best effort: log and keep initializing. Lazy %-args let
                # logging do the formatting.
                logging.error("Could not delete socket %s: %s",
                              self.api_addr, e)

        self.CTRL_PLD_CLASS_MAP = {
            PayloadClass.PATH: {
                PMT.REPLY: self.handle_path_reply,
                PMT.REVOCATION: self.handle_revocation,
            },
            PayloadClass.CERT: {
                CertMgmtType.CERT_CHAIN_REQ: self.process_cert_chain_request,
                CertMgmtType.CERT_CHAIN_REPLY: self.process_cert_chain_reply,
                CertMgmtType.TRC_REPLY: self.process_trc_reply,
                CertMgmtType.TRC_REQ: self.process_trc_request,
            },
        }

        self.SCMP_PLD_CLASS_MAP = {
            SCMPClass.PATH: {
                SCMPPathClass.REVOKED_IF: self.handle_scmp_revocation,
            },
        }

        if run_local_api:
            self._api_sock = ReliableSocket(
                bind_unix=(self.api_addr, "sciond"))
            self._socks.add(self._api_sock, self.handle_accept)
Ejemplo n.º 18
0
 def test(self):
     """Calling the DB chains kwarg parsing, lookup, expiry and sorting."""
     inst = PathSegmentDB()
     parse = create_mock()
     parse.return_value = {"arg1": "val1"}
     expire = create_mock()
     sort = create_mock()
     store = create_mock()
     inst._parse_call_kwargs = parse
     inst._exp_call_records = expire
     inst._sort_call_pcbs = sort
     inst._db = store
     # Call
     result = inst("data", a="b")
     ntools.eq_(result, sort.return_value)
     # Tests
     parse.assert_called_once_with({"a": "b"})
     store.assert_called_once_with("data", arg1="val1")
     expire.assert_called_once_with(store.return_value)
     sort.assert_called_once_with(False, expire.return_value)
Ejemplo n.º 19
0
class LocalPathServer(PathServer):
    """
    SCION Path Server in a non-core AS. Stores up-segments to the core and
    registers down-segments with the CPS. Can cache segments learned from
    a CPS.
    """
    def __init__(self, server_id, conf_dir):
        """
        :param str server_id: server identifier.
        :param str conf_dir: configuration directory.
        """
        super().__init__(server_id, conf_dir)
        # Sanity check that we should indeed be a local path server.
        assert not self.topology.is_core_as, "This shouldn't be a core PS!"
        # Database of up-segments to the core.
        self.up_segments = PathSegmentDB(max_res_no=self.MAX_SEG_NO)

    def _handle_up_segment_record(self, pcb, from_zk=False):
        """
        Store an up-segment. Unless it came from Zookeeper, it is also
        queued in ``_segs_to_zk``.

        :returns: ``{(first_ia, is_sibra)}`` if the segment was added,
            otherwise an empty set.
        """
        if not from_zk:
            self._segs_to_zk.append((PST.UP, pcb))
        if self._add_segment(pcb, self.up_segments, "Up"):
            # Sending pending targets to the core using first registered
            # up-segment.
            self._handle_waiting_targets(pcb)
            return {(pcb.first_ia(), pcb.is_sibra())}
        return set()

    def _handle_down_segment_record(self, pcb, from_zk=None):
        """
        Store a down-segment (``from_zk`` is not used by this method).

        :returns: ``{(last_ia, is_sibra)}`` if the segment was added,
            otherwise an empty set.
        """
        if self._add_segment(pcb, self.down_segments, "Down"):
            return {(pcb.last_ia(), pcb.is_sibra())}
        return set()

    def _handle_core_segment_record(self, pcb, from_zk=None):
        """
        Store a core-segment (``from_zk`` is not used by this method).

        :returns: ``{(first_ia, is_sibra)}`` if the segment was added,
            otherwise an empty set.
        """
        if self._add_segment(pcb, self.core_segments, "Core"):
            return {(pcb.first_ia(), pcb.is_sibra())}
        return set()

    def _remove_revoked_segments(self, rev_info):
        """
        Remove segments that contain a revoked interface. Checks
        ``N_TOKENS_CHECK`` tokens (the revocation token followed by its
        successive SHA-256 hashes) in case previous revocations were missed
        by the PS.

        :param rev_info: The revocation info
        :type rev_info: RevocationInfo
        """
        rev_token = rev_info.rev_token
        for _ in range(self.N_TOKENS_CHECK):
            # pop() removes the token's entry in one step and, unlike
            # indexing the defaultdict, never inserts an empty set first.
            for sid in self.iftoken2seg.pop(rev_token, ()):
                # Delete segment from DB.
                self.up_segments.delete(sid)
                self.down_segments.delete(sid)
                self.core_segments.delete(sid)
            rev_token = SHA256.new(rev_token).digest()

    def path_resolution(self, pkt, new_request=True):
        """
        Handle generic type of a path request.

        Answers from locally stored segments when possible. Otherwise, for a
        new request, asks a core PS and queues ``pkt`` in ``pending_req``
        under ``(dst_ia, sibra flag)``.

        :param pkt: packet whose payload is the path request.
        :param bool new_request: False when re-handling a pending request.
        :returns: True if segments were sent back, False otherwise.
        """
        req = pkt.get_payload()
        dst_ia = req.dst_ia()
        if new_request:
            logging.info("PATH_REQ received: %s", req.short_desc())
        if dst_ia == self.addr.isd_as:
            logging.warning("Dropping request: requested DST is local AS")
            return False
        up_segs = set()
        core_segs = set()
        down_segs = set()
        # dst as==0 means any core AS in the specified ISD
        if self.is_core_as(dst_ia) or dst_ia[1] == 0:
            self._resolve_core(req, up_segs, core_segs)
        else:
            self._resolve_not_core(req, up_segs, core_segs, down_segs)
        if up_segs | core_segs | down_segs:
            self._send_path_segments(pkt, up_segs, core_segs, down_segs)
            return True
        if new_request:
            self._request_paths_from_core(req)
            self.pending_req[(dst_ia, req.p.flags.sibra)].append(pkt)
        else:
            # That could happen when a needed segment expired.
            logging.warning("Handling pending request and needed seg "
                            "is missing. Shouldn't be here (too often).")
        return False

    def _resolve_core(self, req, up_segs, core_segs):
        """
        Dst is core AS.

        Adds matching up-segments (for a destination in the local ISD) and
        core-segments whose ``last_ia`` has an up-segment starting there.
        Results are accumulated into ``up_segs`` / ``core_segs`` in place.
        """
        dst_ia = req.dst_ia()
        params = dst_ia.params()
        params["sibra"] = req.p.flags.sibra
        if dst_ia[0] == self.addr.isd_as[0]:
            # Dst in local ISD. First check whether DST is a (super)-parent.
            up_segs.update(self.up_segments(**params))
        # Check whether dst is known core AS.
        for cseg in self.core_segments(**params):
            # Check do we have an up-seg that is connected to core_seg.
            tmp_up_segs = self.up_segments(first_ia=cseg.last_ia(),
                                           sibra=req.p.flags.sibra)
            if tmp_up_segs:
                up_segs.update(tmp_up_segs)
                core_segs.add(cseg)

    def _resolve_not_core(self, req, up_segs, core_segs, down_segs):
        """
        Dst is regular AS.

        For each down-segment ending at the destination, adds it together
        with either a directly connecting up-segment (destination in the
        local ISD) or a core-segment plus up-segment that connect to it.
        Results are accumulated into the given sets in place.
        """
        sibra = req.p.flags.sibra
        # Loop invariants, computed once instead of per down-segment.
        dst_ia = req.dst_ia()
        dst_in_local_isd = dst_ia[0] == self.addr.isd_as[0]
        # Check if there exists down-seg to DST.
        for dseg in self.down_segments(last_ia=dst_ia, sibra=sibra):
            first_ia = dseg.first_ia()
            if dst_in_local_isd:
                # Dst in local ISD. First try to find direct up-seg.
                dir_up_segs = self.up_segments(first_ia=first_ia, sibra=sibra)
                if dir_up_segs:
                    up_segs.update(dir_up_segs)
                    down_segs.add(dseg)
            # Now try core segments that connect to down segment.
            # PSz: it might make sense to start with up_segments instead.
            for cseg in self.core_segments(first_ia=first_ia, sibra=sibra):
                # And up segments that connect to core segment.
                up_core_segs = self.up_segments(first_ia=cseg.last_ia(),
                                                sibra=sibra)
                if up_core_segs:
                    up_segs.update(up_core_segs)
                    core_segs.add(cseg)
                    down_segs.add(dseg)

    def _request_paths_from_core(self, req):
        """
        Try to request core PS for given target.

        Without a matching up-segment the request is parked in
        ``waiting_targets`` under the local ISD; otherwise a copy is sent to
        the PS service at the first up-segment's ``first_ia``, along the
        reversed path of that segment.
        """
        up_segs = self.up_segments(sibra=req.p.flags.sibra)
        if not up_segs:
            logging.info('Pending target added for %s', req.short_desc())
            # Wait for path to any local core AS
            self.waiting_targets[self.addr.isd_as[0]].append(req)
            return

        # PSz: for multipath it makes sense to query with multiple core ASes
        pcb = up_segs[0]
        logging.info('Send request to core (%s) via %s', req.short_desc(),
                     pcb.short_desc())
        path = pcb.get_path(reverse_direction=True)
        req_pkt = self._build_packet(SVCType.PS,
                                     payload=req.copy(),
                                     path=path,
                                     dst_ia=pcb.first_ia())
        self._send_to_next_hop(req_pkt, path.get_fwd_if())
Ejemplo n.º 20
0
class SCIONDaemon(SCIONElement):
    """
    The SCION Daemon used for retrieving and combining paths.

    Segments received from the local path server are cached in separate
    up/down/core PathSegmentDBs. Path lookups are answered by combining the
    cached segments; outstanding lookups are tracked by a RequestHandler that
    asks the path server when the caches cannot answer. Optionally, a local
    API is served on a ReliableSocket.
    """
    # Max time for a path lookup to succeed/fail.
    TIMEOUT = 5
    # Time a path segment is cached at a host (in seconds).
    SEGMENT_TTL = 300
    MAX_SEG_NO = 5  # TODO: replace by config variable.

    def __init__(self,
                 conf_dir,
                 addr,
                 api_addr,
                 run_local_api=False,
                 port=None):
        """
        Initialize an instance of the class SCIONDaemon.

        :param conf_dir: configuration directory.
        :param addr: host address of the daemon.
        :param api_addr: path of the local API socket. When falsy, defaults
            to "<ISD-AS>.sock" inside SCIOND_API_SOCKDIR.
        :param bool run_local_api: if True, bind the API socket and accept
            clients on it.
        :param port: port passed on to SCIONElement.
        """
        super().__init__("sciond", conf_dir, host_addr=addr, port=port)
        # TODO replace by pathstore instance
        self.up_segments = PathSegmentDB(segment_ttl=self.SEGMENT_TTL,
                                         max_res_no=self.MAX_SEG_NO)
        self.down_segments = PathSegmentDB(segment_ttl=self.SEGMENT_TTL,
                                           max_res_no=self.MAX_SEG_NO)
        self.core_segments = PathSegmentDB(segment_ttl=self.SEGMENT_TTL,
                                           max_res_no=self.MAX_SEG_NO)
        req_name = "SCIONDaemon Requests %s" % self.addr.isd_as
        # The RequestHandler calls back into _check_segments (can the cache
        # answer?), _fetch_segments (ask the path server) and
        # _reply_segments (wake the waiter) for each outstanding request.
        self.requests = RequestHandler.start(
            req_name,
            self._check_segments,
            self._fetch_segments,
            self._reply_segments,
            ttl=self.TIMEOUT,
            key_map=self._req_key_map,
        )
        self._api_sock = None
        self.daemon_thread = None
        os.makedirs(SCIOND_API_SOCKDIR, exist_ok=True)
        self.api_addr = (api_addr or os.path.join(
            SCIOND_API_SOCKDIR, "%s.sock" % self.addr.isd_as))

        self.CTRL_PLD_CLASS_MAP = {
            PayloadClass.PATH: {
                PMT.REPLY: self.handle_path_reply,
                PMT.REVOCATION: self.handle_revocation,
            }
        }

        self.SCMP_PLD_CLASS_MAP = {
            SCMPClass.PATH: {
                SCMPPathClass.REVOKED_IF: self.handle_scmp_revocation
            },
        }

        if run_local_api:
            self._api_sock = ReliableSocket(bind=(self.api_addr, "sciond"))
            self._socks.add(self._api_sock, self.handle_accept)

    @classmethod
    def start(cls, conf_dir, addr, api_addr=None, run_local_api=False, port=0):
        """
        Initializes, starts, and returns a SCIONDaemon object.

        The daemon's `run` loop is executed in a daemon thread wrapped by
        thread_safety_net.

        Example of usage:
        sd = SCIONDaemon.start(conf_dir, addr)
        paths = sd.get_paths(isd_as)
        """
        inst = cls(conf_dir, addr, api_addr, run_local_api, port)
        name = "SCIONDaemon.run %s" % inst.addr.isd_as
        inst.daemon_thread = threading.Thread(target=thread_safety_net,
                                              args=(inst.run, ),
                                              name=name,
                                              daemon=True)
        inst.daemon_thread.start()
        logging.debug("sciond started with api_addr = %s", inst.api_addr)
        return inst

    def _get_msg_meta(self, packet, addr, sock):
        """
        Build message metadata. Anything not read from the UDP socket came
        from the local API and only carries the socket it arrived on.
        """
        if sock != self._udp_sock:
            return packet, SockOnlyMetadata.from_values(sock)  # API socket
        else:
            return super()._get_msg_meta(packet, addr, sock)

    def handle_msg_meta(self, msg, meta):
        """
        Main routine to handle incoming SCION messages.

        Local API messages are dispatched to api_handle_request; everything
        else goes through the generic SCIONElement handling.
        """
        if isinstance(meta, SockOnlyMetadata):  # From SCIOND API
            self.api_handle_request(msg, meta)
            return
        logging.debug("handle_msg_meta()")
        super().handle_msg_meta(msg, meta)

    def handle_path_reply(self, path_reply, meta):
        """
        Handle path reply from local path server.

        Each segment is stored in the DB matching its type. For every segment
        that was newly added, the RequestHandler is notified with the
        (destination ISD-AS, flags) key the segment can help answer.
        """
        added = set()
        map_ = {
            PST.UP: self._handle_up_seg,
            PST.DOWN: self._handle_down_seg,
            PST.CORE: self._handle_core_seg,
        }
        for type_, pcb in path_reply.iter_pcbs():
            ret = map_[type_](pcb)
            if not ret:
                continue
            flags = (PATH_FLAG_SIBRA, ) if pcb.is_sibra() else ()
            added.add((ret, flags))
        logging.debug("Added: %s", added)
        for dst_ia, flags in added:
            self.requests.put(((dst_ia, flags), None))
        logging.debug("Closing meta")
        meta.close()

    def _handle_up_seg(self, pcb):
        """
        Store an up-segment ending at the local AS.

        :returns: the segment's first ISD-AS if it was newly added, else None.
        """
        if self.addr.isd_as != pcb.last_ia():
            return None
        if self.up_segments.update(pcb) == DBResult.ENTRY_ADDED:
            logging.debug("Up segment added: %s", pcb.short_desc())
            return pcb.first_ia()
        return None

    def _handle_down_seg(self, pcb):
        """
        Store a down-segment that does not end at the local AS.

        :returns: the segment's last ISD-AS if it was newly added, else None.
        """
        last_ia = pcb.last_ia()
        if self.addr.isd_as == last_ia:
            return None
        if self.down_segments.update(pcb) == DBResult.ENTRY_ADDED:
            logging.debug("Down segment added: %s", pcb.short_desc())
            return last_ia
        return None

    def _handle_core_seg(self, pcb):
        """
        Store a core-segment.

        :returns: the segment's first ISD-AS if it was newly added, else None.
        """
        if self.core_segments.update(pcb) == DBResult.ENTRY_ADDED:
            logging.debug("Core segment added: %s", pcb.short_desc())
            return pcb.first_ia()
        return None

    def api_handle_request(self, msg, meta):
        """
        Handle local API's requests.

        The first byte selects the request type: 0 is a path request (served
        in a separate thread, since it may block up to TIMEOUT), 1 asks for
        the local ISD-AS.
        """
        if msg[0] == 0:  # path request
            logging.debug('API: path request')
            threading.Thread(target=thread_safety_net,
                             args=(self._api_handle_path_request, msg, meta),
                             daemon=True).start()
        elif msg[0] == 1:  # address request
            logging.debug('API: local ISD-AS request')
            self.send_meta(self.addr.isd_as.pack(), meta)
        else:
            logging.warning("API: type %d not supported.", msg[0])

    def _api_handle_path_request(self, msg, meta):
        """
        Path request:
          | \x00 (1B) | ISD (12bits) |  AS (20bits)  |
        Reply:
          |p1_len(1B)|p1((p1_len*8)B)|fh_type(1B)|fh_IP(?B)|fh_port(2B)|mtu(2B)|
           p1_if_count(1B)|p1_if_1|...|p1_if_n|
           p2_len(1B)|...
         or b"" when no path found.
        Each interface entry p1_if_i is the packed ISD-AS (ISD_AS.LEN bytes)
        followed by the interface ID (2B).
        """
        dst_ia = ISD_AS(msg[1:ISD_AS.LEN + 1])
        thread = threading.current_thread()
        thread.name = "SCIONDaemon API id:%s %s -> %s" % (
            thread.ident, self.addr.isd_as, dst_ia)
        paths = self.get_paths(dst_ia)
        reply = []
        logging.debug("Replying to api request for %s with %d paths", dst_ia,
                      len(paths))
        for path in paths:
            raw_path = path.pack()
            fwd_if = path.get_fwd_if()
            # Set dummy host addr if path is empty.
            if fwd_if == 0:
                haddr, port = HostAddrNone(), SCION_UDP_EH_DATA_PORT
            else:
                br = self.ifid2br[fwd_if]
                haddr, port = br.addr, br.port
            path_len = len(raw_path) // 8
            reply.append(
                struct.pack("!B", path_len) + raw_path +
                struct.pack("!B", haddr.TYPE) + haddr.pack() +
                struct.pack("!H", port) + struct.pack("!H", path.mtu) +
                struct.pack("!B", len(path.interfaces)))
            for interface in path.interfaces:
                isd_as, link = interface
                reply.append(isd_as.pack())
                reply.append(struct.pack("!H", link))
        self.send_meta(b"".join(reply), meta)

    def handle_scmp_revocation(self, pld, meta):
        """Extract the RevocationInfo carried by an SCMP message and apply it."""
        rev_info = RevocationInfo.from_raw(pld.info.rev_info)
        self.handle_revocation(rev_info, meta)

    def handle_revocation(self, rev_info, meta):
        """
        Drop all cached segments affected by a (valid) revocation.
        """
        assert isinstance(rev_info, RevocationInfo)
        if not self._validate_revocation(rev_info):
            return
        # Go through all segment databases and remove affected segments.
        removed_up = self._remove_revoked_pcbs(self.up_segments, rev_info)
        removed_core = self._remove_revoked_pcbs(self.core_segments, rev_info)
        removed_down = self._remove_revoked_pcbs(self.down_segments, rev_info)
        logging.info("Removed %d UP- %d CORE- and %d DOWN-Segments.",
                     removed_up, removed_core, removed_down)

    def _remove_revoked_pcbs(self, db, rev_info):
        """
        Removes all segments from 'db' that contain an AS entry for which
        `_verify_revocation_for_asm(rev_info, asm)` holds. Nothing is removed
        if the revocation's epoch cannot be verified.

        :param db: The PathSegmentDB.
        :type db: :class:`lib.path_db.PathSegmentDB`
        :param rev_info: The revocation info
        :type rev_info: RevocationInfo

        :returns: The number of deletions.
        :rtype: int
        """

        if not ConnectedHashTree.verify_epoch(rev_info.p.epoch):
            logging.debug(
                "Failed to verify epoch: rev_info epoch %d,current epoch %d.",
                rev_info.p.epoch, ConnectedHashTree.get_current_epoch())
            return 0

        to_remove = []
        for segment in db(full=True):
            for asm in segment.iter_asms():
                if self._verify_revocation_for_asm(rev_info, asm):
                    logging.debug("Removing segment: %s",
                                  segment.short_desc())
                    to_remove.append(segment.get_hops_hash())
                    # One revoked AS entry is enough to drop the segment; do
                    # not queue (and log) the same segment more than once.
                    break
        return db.delete_all(to_remove)

    def get_paths(self, dst_ia, flags=()):
        """
        Return a list of paths.

        Local (or, for a core AS, any-core-AS-in-local-ISD) destinations get a
        single empty path. Otherwise a lookup is queued with the
        RequestHandler and awaited for at most TIMEOUT seconds; on timeout an
        empty list is returned.
        """
        logging.debug("Paths requested for %s %s", dst_ia, flags)
        if self.addr.isd_as == dst_ia or (self.addr.isd_as.any_as() == dst_ia
                                          and self.topology.is_core_as):
            # Either the destination is the local AS, or the destination is any
            # core AS in this ISD, and the local AS is in the core
            empty = SCIONPath()
            empty.mtu = self.topology.mtu
            return [empty]
        deadline = SCIONTime.get_time() + self.TIMEOUT
        e = threading.Event()
        self.requests.put(((dst_ia, flags), e))
        if not self._wait_for_events([e], deadline):
            logging.error("Query timed out for %s", dst_ia)
            return []
        return self.path_resolution(dst_ia, flags=flags)

    def path_resolution(self, dst_ia, flags=()):
        """
        Combine cached segments into paths towards `dst_ia`.

        The combination strategy depends on whether the local AS and the
        destination are core ASes. For SIBRA requests the PCBs are replaced by
        (reservation id, reservation block, interface list) tuples.
        """
        # dst as == 0 means any core AS in the specified ISD.
        dst_is_core = self.is_core_as(dst_ia) or dst_ia[1] == 0
        sibra = PATH_FLAG_SIBRA in flags
        if self.topology.is_core_as:
            if dst_is_core:
                ret = self._resolve_core_core(dst_ia, sibra=sibra)
            else:
                ret = self._resolve_core_not_core(dst_ia, sibra=sibra)
        elif dst_is_core:
            ret = self._resolve_not_core_core(dst_ia, sibra=sibra)
        elif sibra:
            ret = self._resolve_not_core_not_core_sibra(dst_ia)
        else:
            ret = self._resolve_not_core_not_core_scion(dst_ia)
        if not sibra:
            return ret
        # FIXME(kormat): Strip off PCBs, and just return sibra reservation
        # blocks
        return self._sibra_strip_pcbs(self._strip_nones(ret))

    def _resolve_core_core(self, dst_ia, sibra=False):
        """Resolve path from core to core."""
        res = set()
        for cseg in self.core_segments(last_ia=self.addr.isd_as,
                                       sibra=sibra,
                                       **dst_ia.params()):
            res.add((None, cseg, None))
        if sibra:
            return res
        return PathCombinator.tuples_to_full_paths(res)

    def _resolve_core_not_core(self, dst_ia, sibra=False):
        """Resolve path from core to non-core."""
        res = set()
        # First check whether there is a direct path.
        for dseg in self.down_segments(first_ia=self.addr.isd_as,
                                       last_ia=dst_ia,
                                       sibra=sibra):
            res.add((None, None, dseg))
        # Check core-down combination.
        for dseg in self.down_segments(last_ia=dst_ia, sibra=sibra):
            dseg_ia = dseg.first_ia()
            if self.addr.isd_as == dseg_ia:
                # Down-segments starting at the local AS were already added
                # as direct paths above; no core segment is needed for them.
                continue
            for cseg in self.core_segments(first_ia=dseg_ia,
                                           last_ia=self.addr.isd_as,
                                           sibra=sibra):
                res.add((None, cseg, dseg))
        if sibra:
            return res
        return PathCombinator.tuples_to_full_paths(res)

    def _resolve_not_core_core(self, dst_ia, sibra=False):
        """Resolve path from non-core to core."""
        res = set()
        params = dst_ia.params()
        params["sibra"] = sibra
        if dst_ia[0] == self.addr.isd_as[0]:
            # Dst in local ISD. First check whether DST is a (super)-parent.
            for useg in self.up_segments(**params):
                res.add((useg, None, None))
        # Check whether dst is known core AS.
        for cseg in self.core_segments(**params):
            # Check do we have an up-seg that is connected to core_seg.
            for useg in self.up_segments(first_ia=cseg.last_ia(), sibra=sibra):
                res.add((useg, cseg, None))
        if sibra:
            return res
        return PathCombinator.tuples_to_full_paths(res)

    def _resolve_not_core_not_core_scion(self, dst_ia):
        """Resolve SCION path from non-core to non-core."""
        up_segs = self.up_segments()
        down_segs = self.down_segments(last_ia=dst_ia)
        core_segs = self._calc_core_segs(dst_ia[0], up_segs, down_segs)
        full_paths = PathCombinator.build_shortcut_paths(up_segs, down_segs)
        tuples = []
        for up_seg in up_segs:
            for down_seg in down_segs:
                tuples.append((up_seg, None, down_seg))
                for core_seg in core_segs:
                    tuples.append((up_seg, core_seg, down_seg))
        full_paths.extend(PathCombinator.tuples_to_full_paths(tuples))
        return full_paths

    def _resolve_not_core_not_core_sibra(self, dst_ia):
        """Resolve SIBRA path from non-core to non-core."""
        res = set()
        up_segs = set(self.up_segments(sibra=True))
        down_segs = set(self.down_segments(last_ia=dst_ia, sibra=True))
        for up_seg, down_seg in product(up_segs, down_segs):
            src_core_ia = up_seg.first_ia()
            dst_core_ia = down_seg.first_ia()
            if src_core_ia == dst_core_ia:
                res.add((up_seg, down_seg))
                continue
            for core_seg in self.core_segments(first_ia=dst_core_ia,
                                               last_ia=src_core_ia,
                                               sibra=True):
                res.add((up_seg, core_seg, down_seg))
        return res

    def _strip_nones(self, set_):
        """Strip None entries from a set of tuples"""
        res = []
        for tup in set_:
            res.append(tuple(filter(None, tup)))
        return res

    def _sibra_strip_pcbs(self, paths):
        """Convert every PCB of every path via _sibra_strip_pcb."""
        ret = []
        for pcbs in paths:
            resvs = []
            for pcb in pcbs:
                resvs.append(self._sibra_strip_pcb(pcb))
            ret.append(resvs)
        return ret

    def _sibra_strip_pcb(self, pcb):
        """
        Turn a SIBRA PCB into (reservation id, ResvBlockSteady, iflist).

        AS entries are walked in reverse for "up" reservations so that they
        line up with the reservation's SOFs.
        """
        assert pcb.is_sibra()
        pcb_ext = pcb.sibra_ext
        resv_info = pcb_ext.info
        resv = ResvBlockSteady.from_values(resv_info, pcb.get_n_hops())
        asms = pcb.iter_asms()
        if pcb_ext.p.up:
            asms = reversed(list(asms))
        iflist = []
        for sof, asm in zip(pcb_ext.iter_sofs(), asms):
            resv.sofs.append(sof)
            iflist.extend(
                self._sibra_add_ifs(asm.isd_as(), sof, resv_info.fwd_dir))
        assert resv.num_hops == len(resv.sofs)
        return pcb_ext.p.id, resv, iflist

    def _sibra_add_ifs(self, isd_as, sof, fwd):
        """
        Return the non-zero (isd_as, ifid) pairs of a SOF, ingress first when
        `fwd` is true and egress first otherwise.
        """
        def _add(ifid):
            if ifid:
                ret.append((isd_as, ifid))

        ret = []
        if fwd:
            _add(sof.ingress)
            _add(sof.egress)
        else:
            _add(sof.egress)
            _add(sof.ingress)
        return ret

    def _wait_for_events(self, events, deadline):
        """
        Wait on a set of events, but only until the specified deadline. Returns
        the number of events that happened while waiting.
        """
        count = 0
        for e in events:
            if e.wait(max(0, deadline - SCIONTime.get_time())):
                count += 1
        return count

    def _check_segments(self, key):
        """
        Called by RequestHandler to check if a given path request can be
        fulfilled.
        """
        dst_ia, flags = key
        return self.path_resolution(dst_ia, flags=flags)

    def _fetch_segments(self, key, _):
        """
        Called by RequestHandler to fetch the requested path.

        Sends a PathSegmentReq to the first path service address returned by
        dns_query_topo; logs and gives up if the lookup fails.
        """
        dst_ia, flags = key
        try:
            addr, port = self.dns_query_topo(PATH_SERVICE)[0]
        except SCIONServiceLookupError:
            log_exception("Error querying path service:")
            return
        req = PathSegmentReq.from_values(self.addr.isd_as, dst_ia, flags=flags)
        logging.debug("Sending path request: %s", req.short_desc())
        meta = self.DefaultMeta.from_values(host=addr, port=port)
        self.send_meta(req, meta)

    def _reply_segments(self, key, e):
        """
        Called by RequestHandler to signal that the request has been fulfilled.
        """
        e.set()

    def _req_key_map(self, key, req_keys):
        """
        Called by RequestHandler to know which requests can be answered by
        `key`.
        """
        ans_ia, ans_flags = key
        ans_f_set = set(ans_flags)
        ret = []
        for req_ia, req_flags in req_keys:
            req_f_set = set(req_flags)
            if req_f_set != ans_f_set and (not ans_f_set & req_f_set):
                # The answer and the request have no flags in common, so skip
                # it.
                continue
            if (req_ia == ans_ia) or (req_ia == ans_ia.any_as()):
                # Covers the case where a request was for ISD-0 (i.e. any path
                # to a core AS in the specified ISD)
                ret.append((req_ia, req_flags))
        return ret

    def _calc_core_segs(self, dst_isd, up_segs, down_segs):
        """
        Calculate all possible core segments joining the provided up and down
        segments.

        Source core ASes are taken from the first AS of each up-segment (in
        the local ISD), destination core ASes from the first AS of each
        down-segment (in `dst_isd`). Returns a list of the known core
        segments connecting any such pair.
        """
        src_core_ases = set()
        dst_core_ases = set()
        for seg in up_segs:
            src_core_ases.add(seg.first_ia()[1])
        for seg in down_segs:
            dst_core_ases.add(seg.first_ia()[1])
        # Generate all possible AS pairs
        as_pairs = list(product(src_core_ases, dst_core_ases))
        return self._find_core_segs(self.addr.isd_as[0], dst_isd, as_pairs)

    def _find_core_segs(self, src_isd, dst_isd, as_pairs):
        """
        Given a set of AS pairs across 2 ISDs, return the core segments
        connecting those pairs
        """
        core_segs = []
        for src_core_as, dst_core_as in as_pairs:
            src_ia = ISD_AS.from_values(src_isd, src_core_as)
            dst_ia = ISD_AS.from_values(dst_isd, dst_core_as)
            if src_ia == dst_ia:
                continue
            seg = self.core_segments(first_ia=dst_ia, last_ia=src_ia)
            if seg:
                core_segs.extend(seg)
        return core_segs
Ejemplo n.º 21
0
class PathServer(SCIONElement, metaclass=ABCMeta):
    """
    The SCION Path Server.
    """
    SERVICE_TYPE = PATH_SERVICE
    MAX_SEG_NO = 5  # TODO: replace by config variable.
    # ZK path for incoming PATHs
    ZK_PATH_CACHE_PATH = "path_cache"
    # ZK path for incoming REVs
    ZK_REV_CACHE_PATH = "rev_cache"
    # Max number of segments per propagation packet
    PROP_LIMIT = 5
    # Max number of segments per ZK cache entry
    ZK_SHARE_LIMIT = 10
    # Time to store revocations in zookeeper
    ZK_REV_OBJ_MAX_AGE = HASHTREE_EPOCH_TIME

    def __init__(self, server_id, conf_dir):
        """
        Create the path server state and join the Zookeeper party.

        Sets up the down/core segment DBs, bookkeeping for pending requests
        and revocations, the control/SCMP payload dispatch tables, and the
        ZK-backed shared caches for paths and revocations.

        :param str server_id: server identifier.
        :param str conf_dir: configuration directory.
        """
        super().__init__(server_id, conf_dir)
        self.down_segments = PathSegmentDB(max_res_no=self.MAX_SEG_NO)
        self.core_segments = PathSegmentDB(max_res_no=self.MAX_SEG_NO)
        self.pending_req = defaultdict(list)  # Dict of pending requests.
        # Held while pending_req is scanned in _handle_pending_requests.
        self.pen_req_lock = threading.Lock()
        # Used when l/cPS doesn't have up/dw-path.
        self.waiting_targets = defaultdict(list)
        self.revocations = RevCache()
        # A mapping from (hash tree root of AS, IFID) to segments
        # (sets of segment hops hashes, filled by _add_rev_mappings).
        self.htroot_if2seg = ExpiringDict(1000, HASHTREE_TTL)
        # Guards htroot_if2seg.
        self.htroot_if2seglock = Lock()
        # Handlers for incoming control payloads, keyed by payload class and
        # message type.
        self.CTRL_PLD_CLASS_MAP = {
            PayloadClass.PATH: {
                PMT.REQUEST: self.path_resolution,
                PMT.REPLY: self.handle_path_segment_record,
                PMT.REG: self.handle_path_segment_record,
                PMT.REVOCATION: self._handle_revocation,
                PMT.SYNC: self.handle_path_segment_record,
            },
            PayloadClass.CERT: {
                CertMgmtType.CERT_CHAIN_REQ: self.process_cert_chain_request,
                CertMgmtType.CERT_CHAIN_REPLY: self.process_cert_chain_reply,
                CertMgmtType.TRC_REPLY: self.process_trc_reply,
                CertMgmtType.TRC_REQ: self.process_trc_request,
            },
        }
        # Handlers for incoming SCMP messages (interface revocations).
        self.SCMP_PLD_CLASS_MAP = {
            SCMPClass.PATH: {
                SCMPPathClass.REVOKED_IF: self._handle_scmp_revocation,
            },
        }
        # Items queued for sharing via ZK (presumably drained by
        # _share_via_zk / _share_revs_via_zk — not visible here).
        self._segs_to_zk = deque()
        self._revs_to_zk = deque()
        # Identity announced in the ZK party: ISD-AS, server id and address.
        self._zkid = ZkID.from_values(self.addr.isd_as, self.id,
                                      [(self.addr.host, self._port)])
        self.zk = Zookeeper(self.topology.isd_as, PATH_SERVICE,
                            self._zkid.copy().pack(), self.topology.zookeepers)
        self.zk.retry("Joining party", self.zk.party_setup)
        # Shared caches: entries read from ZK are passed to the handlers.
        self.path_cache = ZkSharedCache(self.zk, self.ZK_PATH_CACHE_PATH,
                                        self._handle_paths_from_zk)
        self.rev_cache = ZkSharedCache(self.zk, self.ZK_REV_CACHE_PATH,
                                       self._rev_entries_handler)

    def worker(self):
        """
        Worker thread that takes care of reading shared paths from ZK, and
        handling master election for core servers.

        Runs once per `worker_cycle` seconds while `run_flag` is set. Each
        cycle processes the shared path and revocation caches and tries to
        take the ZK lock; only the lock holder (master) expires old entries
        of the shared caches. Afterwards segments are propagated/synced and
        pending requests are retried, whether or not ZK was reachable.
        """
        worker_cycle = 1.0
        start = SCIONTime.get_time()
        was_master = False
        while self.run_flag.is_set():
            sleep_interval(start, worker_cycle, "cPS.worker cycle",
                           self._quiet_startup())
            start = SCIONTime.get_time()
            try:
                self.zk.wait_connected()
                self.path_cache.process()
                self.rev_cache.process()
                # Try to become a master.
                is_master = self.zk.get_lock(lock_timeout=0, conn_timeout=0)
                if is_master:
                    if not was_master:
                        logging.info("Became master")
                    self.path_cache.expire(self.config.propagation_time * 10)
                    self.rev_cache.expire(self.ZK_REV_OBJ_MAX_AGE)
                    was_master = True
                else:
                    was_master = False
            except ZkNoConnection:
                logging.warning('worker(): ZkNoConnection')
                # Mastership cannot be assumed to survive a lost ZK
                # connection, so report it again once the lock is re-taken.
                was_master = False
            self._update_master()
            self._propagate_and_sync()
            self._handle_pending_requests()

    def _update_master(self):
        """Per-cycle hook called by worker(); does nothing in the base class."""

    def _rev_entries_handler(self, raw_entries):
        """
        Callback of the shared ZK revocation cache: parse every raw entry as a
        RevocationInfo and remove the segments it revokes.
        """
        for rev_info in map(RevocationInfo.from_raw, raw_entries):
            self._remove_revoked_segments(rev_info)

    def _add_rev_mappings(self, pcb):
        """
        Add if revocation token to segment ID mappings.

        For every AS entry of `pcb`, the segment's hops hash is recorded under
        (hash tree root, egress IF) and (hash tree root, ingress IF) of the
        entry's first pcbm, so that _remove_revoked_segments can find it.
        """
        seg_id = pcb.get_hops_hash()
        with self.htroot_if2seglock:
            for asm in pcb.iter_asms():
                hof = asm.pcbm(0).hof()
                root = asm.p.hashTreeRoot
                for if_id in (hof.egress_if, hof.ingress_if):
                    seg_ids = self.htroot_if2seg.setdefault((root, if_id),
                                                            set())
                    seg_ids.add(seg_id)

    @abstractmethod
    def _handle_up_segment_record(self, pcb, **kwargs):
        """Store an up-segment record; implemented by concrete servers."""
        raise NotImplementedError

    @abstractmethod
    def _handle_down_segment_record(self, pcb, **kwargs):
        """Store a down-segment record; implemented by concrete servers."""
        raise NotImplementedError

    @abstractmethod
    def _handle_core_segment_record(self, pcb, **kwargs):
        """Store a core-segment record; implemented by concrete servers."""
        raise NotImplementedError

    def _add_segment(self, pcb, seg_db, name, reverse=False):
        """
        Insert or refresh `pcb` in `seg_db`.

        Revocation mappings are (re)recorded whenever the DB reports the entry
        as added or updated.

        :param name: label used in the log messages (e.g. "Up").
        :returns: True only if the segment was newly added.
        """
        outcome = seg_db.update(pcb, reverse=reverse)
        if outcome not in (DBResult.ENTRY_ADDED, DBResult.ENTRY_UPDATED):
            return False
        self._add_rev_mappings(pcb)
        if outcome == DBResult.ENTRY_ADDED:
            logging.info("%s-Segment registered: %s", name, pcb.short_desc())
            return True
        logging.debug("%s-Segment updated: %s", name, pcb.short_desc())
        return False

    def _handle_scmp_revocation(self, pld, meta):
        """Extract the RevocationInfo from an SCMP payload and handle it."""
        self._handle_revocation(RevocationInfo.from_raw(pld.info.rev_info),
                                meta)

    def _handle_revocation(self, rev_info, meta):
        """
        Handles a revocation of a segment, interface or hop.

        Invalid revocations and revocations received from another ISD are
        dropped; already-known ones are dropped as well (returning False). A
        new revocation is cached, queued for sharing via ZK, applied to the
        local segment DBs and forwarded to other path servers.

        :param rev_info: The RevocationInfo object.
        :param meta: metadata of the message carrying the revocation.
        """
        assert isinstance(rev_info, RevocationInfo)
        if not self._validate_revocation(rev_info):
            return
        if meta.ia[0] != self.addr.isd_as[0]:
            logging.info("Dropping revocation received from a different ISD.")
            return

        already_seen = rev_info in self.revocations
        if already_seen:
            logging.debug("Already received revocation. Dropping...")
            return False
        self.revocations.add(rev_info)
        logging.debug("Received revocation from %s:\n%s", meta.get_addr(),
                      rev_info)
        # Pack a copy rather than rev_info itself.
        packed = rev_info.copy().pack()
        self._revs_to_zk.append(packed)
        # Drop segments using the revoked interface, then tell the other PSes.
        self._remove_revoked_segments(rev_info)
        self._forward_revocation(rev_info, meta)

    def _remove_revoked_segments(self, rev_info):
        """
        Delete every cached segment that uses the revoked interface.

        After the revocation's epoch is verified, both candidate hash tree
        roots from ConnectedHashTree.get_possible_hashes(rev_info) are looked
        up together with the revoked interface ID in `htroot_if2seg`. Every
        segment ID found there is deleted from the down- and core-segment DBs
        (and from the up-segment DB on non-core ASes), and the consumed
        mapping entries are removed.

        :param rev_info: The revocation info
        :type rev_info: RevocationInfo
        """
        if not ConnectedHashTree.verify_epoch(rev_info.p.epoch):
            return
        (hash01, hash12) = ConnectedHashTree.get_possible_hashes(rev_info)
        if_id = rev_info.p.ifID

        down_segs_removed = 0
        core_segs_removed = 0
        up_segs_removed = 0
        with self.htroot_if2seglock:
            for h in (hash01, hash12):
                for sid in self.htroot_if2seg.pop((h, if_id), []):
                    if self.down_segments.delete(
                            sid) == DBResult.ENTRY_DELETED:
                        down_segs_removed += 1
                    if self.core_segments.delete(
                            sid) == DBResult.ENTRY_DELETED:
                        core_segs_removed += 1
                    if not self.topology.is_core_as:
                        if (self.up_segments.delete(sid) ==
                                DBResult.ENTRY_DELETED):
                            up_segs_removed += 1
        # Counters are locals, so logging does not need to hold the lock.
        logging.info(
            "Removed segments containing IF %d: "
            "UP: %d DOWN: %d CORE: %d",
            if_id, up_segs_removed, down_segs_removed, core_segs_removed)

    @abstractmethod
    def _forward_revocation(self, rev_info, meta):
        """
        Forwards a revocation to other path servers that need to be notified.

        Which servers those are depends on the concrete (core / local) path
        server, so subclasses must implement this.

        :param rev_info: The RevInfo object.
        :param meta: The MessageMeta object.
        """
        raise NotImplementedError

    def _send_path_segments(self, req, meta, up=None, core=None, down=None):
        """
        Sends path-segments to requester (depending on Path Server's location).

        Missing/empty segment sets are treated as empty sets. Peer revocations
        affecting the sent segments are attached to the reply. If there is no
        segment at all, only a warning is logged and nothing is sent.
        """
        up, core, down = up or set(), core or set(), down or set()
        all_segs = up | core | down
        if not all_segs:
            logging.warning("No segments to send")
            return
        reply = PathRecordsReply.from_values(
            {PST.UP: up, PST.CORE: core, PST.DOWN: down},
            self._peer_revs_for_segs(all_segs))
        self.send_meta(reply, meta)
        logging.info(
            "Sending PATH_REPLY with %d segment(s) to:%s "
            "port:%s in response to: %s",
            len(all_segs), meta.get_addr(), meta.port, req.short_desc())

    def _peer_revs_for_segs(self, segs):
        """
        Returns a list of peer revocations for segments in 'segs'.

        For each AS entry of each segment, the entries returned by
        `asm.iter_pcbms(1)` (the peer entries) are checked. Every cached
        revocation for their ingress or egress interface is included as a
        copy, so a segment crossing several revoked peer links carries all of
        the corresponding revocations, not just the first one found.
        """
        revs_to_add = set()
        for seg in segs:
            for asm in seg.iter_asms():
                isd_as = asm.isd_as()
                for pcbm in asm.iter_pcbms(1):
                    hof = pcbm.hof()
                    for if_id in (hof.ingress_if, hof.egress_if):
                        rev_info = self.revocations.get((isd_as, if_id))
                        if rev_info:
                            revs_to_add.add(rev_info.copy())
        return list(revs_to_add)

    def _handle_pending_requests(self):
        """
        Retry all pending requests and drop the ones that got answered.

        Under `pen_req_lock`, each (req, meta) pair for which
        path_resolution(req, meta, new_request=False) is truthy has its meta
        closed and is removed from its list; keys whose list becomes empty
        are deleted from `pending_req`.
        """
        with self.pen_req_lock:
            drained_keys = []
            for key, waiting in self.pending_req.items():
                served = []
                for entry in waiting:
                    req, meta = entry
                    if not self.path_resolution(req, meta, new_request=False):
                        continue
                    meta.close()
                    served.append(entry)
                # Clean state.
                for entry in served:
                    waiting.remove(entry)
                if not waiting:
                    drained_keys.append(key)
            for key in drained_keys:
                del self.pending_req[key]

    def _handle_paths_from_zk(self, raw_entries):
        """
        Handles cached paths through ZK, passed as a list.

        Each raw entry is a packed PathSegmentRecords that may hold several
        segments; every segment is handed to process_path_seg with the
        `from_zk` parameter set.
        """
        seg_count = 0
        for raw in raw_entries:
            recs = PathSegmentRecords.from_raw(raw)
            for type_, pcb in recs.iter_pcbs():
                seg_meta = PathSegMeta(pcb,
                                       self.continue_seg_processing,
                                       type_=type_,
                                       params={'from_zk': True})
                self.process_path_seg(seg_meta)
                seg_count += 1
        # Report segments and ZK entries separately: an entry holds a batch.
        logging.debug("Processed %s segments from ZK (%s entries)",
                      seg_count, len(raw_entries))

    def handle_path_segment_record(self, seg_recs, meta):
        """
        Handles paths received from the network.

        Peer revocations carried in `seg_recs` are cached first. Then every
        segment is wrapped in a PathSegMeta (continuing in
        continue_seg_processing with the dispatch params) and passed to
        process_path_seg.
        """
        params = self._dispatch_params(seg_recs, meta)
        # Add revocations for peer interfaces included in the path segments.
        for rev_info in seg_recs.iter_rev_infos():
            self.revocations.add(rev_info)
        # Verify pcbs and process them
        for type_, pcb in seg_recs.iter_pcbs():
            self.process_path_seg(
                PathSegMeta(pcb, self.continue_seg_processing, meta, type_,
                            params))

    def continue_seg_processing(self, seg_meta):
        """
        Resume processing of a verified segment coming from the network or ZK.

        The segment is dispatched to the handler for its type (which stores it
        in the path DB) and then pending requests are re-checked.
        """
        self._dispatch_segment_record(seg_meta.type, seg_meta.seg,
                                      **seg_meta.params)
        self._handle_pending_requests()

    def _dispatch_segment_record(self, type_, seg, **kwargs):
        """
        Route a segment to the handler registered for its type.

        Segments that contain a revoked interface (per _validate_segment) are
        dropped with a debug message. An unknown type_ raises KeyError from
        the handler lookup.
        """
        if self._validate_segment(seg):
            handlers = {
                PST.UP: self._handle_up_segment_record,
                PST.CORE: self._handle_core_segment_record,
                PST.DOWN: self._handle_down_segment_record,
            }
            handlers[type_](seg, **kwargs)
        else:
            logging.debug("Not adding segment due to revoked interface:\n%s",
                          seg.short_desc())

    def _validate_segment(self, seg):
        """
        Check segment for revoked upstream/downstream interfaces.

        :param seg: The PathSegment object.
        :return: False, if the path segment contains a revoked upstream/
            downstream interface (not peer). True otherwise.
        """
        for asm in seg.iter_asms():
            pcbm = asm.pcbm(0)
            for if_id in [pcbm.p.inIF, pcbm.p.outIF]:
                rev_info = self.revocations.get((asm.isd_as(), if_id))
                if rev_info:
                    logging.debug("Found revoked interface (%d) in segment "
                                  "%s." % (rev_info.p.ifID, seg.short_desc()))
                    return False
        return True

    def _dispatch_params(self, pld, meta):
        return {}

    def _propagate_and_sync(self):
        self._share_via_zk()
        self._share_revs_via_zk()

    def _gen_prop_recs(self, queue, limit=PROP_LIMIT):
        """
        Drain ``queue`` of (type, pcb) pairs into batches.

        Items are popped from the left of ``queue``. Each yielded batch is a
        defaultdict(list) mapping segment type to copies of the PCBs, holding
        at most ``limit`` PCBs in total; a trailing partial batch is yielded
        if non-empty. The queue is consumed lazily as the generator advances.
        """
        batch = defaultdict(list)
        batch_len = 0
        while queue:
            seg_type, pcb = queue.popleft()
            batch[seg_type].append(pcb.copy())
            batch_len += 1
            if batch_len < limit:
                continue
            yield batch
            batch, batch_len = defaultdict(list), 0
        if batch:
            yield batch

    @abstractmethod
    def path_resolution(self, path_request, meta, new_request):
        """
        Handles all types of path request.

        Must be implemented by subclasses. The pending-request retry loop in
        this class calls it with ``new_request=False`` and treats a truthy
        result as the request having been answered (it then closes ``meta``
        and drops the request from pending_req).

        :param path_request: the path request being resolved.
        :param meta: metadata of the requester.
        :param bool new_request: False when re-trying a pending request.
        """
        raise NotImplementedError()

    def _handle_waiting_targets(self, pcb):
        """
        Handle any queries that are waiting for a path to any core AS in an ISD.
        """
        dst_ia = pcb.first_ia()
        if not self.is_core_as(dst_ia):
            logging.warning("Invalid waiting target, not a core AS: %s",
                            dst_ia)
            return
        self._send_waiting_queries(dst_ia[0], pcb)

    def _send_waiting_queries(self, dst_isd, pcb):
        """
        Send every request waiting on ``dst_isd`` through ``pcb``.

        Requests are taken from the front of the waiting list one at a time
        and sent to SVCType.PS_A in the segment's first AS, over the
        segment's path in reverse direction.
        """
        waiting = self.waiting_targets[dst_isd]
        if not waiting:
            return
        rev_path = pcb.get_path(reverse_direction=True)
        first_ia = pcb.first_ia()
        while waiting:
            req = waiting.pop(0)
            req_meta = self.DefaultMeta.from_values(ia=first_ia,
                                                    path=rev_path,
                                                    host=SVCType.PS_A)
            self.send_meta(req, req_meta)
            logging.info("Waiting request (%s) sent via %s",
                         req.short_desc(), pcb.short_desc())

    def _share_via_zk(self):
        """
        Flush queued segments to the shared ZK path cache.

        self._segs_to_zk is drained in batches of at most ZK_SHARE_LIMIT
        PCBs; each batch is packed as PathSegmentRecords and stored via
        _zk_write.
        """
        pending = self._segs_to_zk
        if pending:
            logging.info("Sharing %d segment(s) via ZK", len(pending))
            batches = self._gen_prop_recs(pending, limit=self.ZK_SHARE_LIMIT)
            for batch in batches:
                self._zk_write(PathSegmentRecords.from_values(batch).pack())

    def _share_revs_via_zk(self):
        if not self._revs_to_zk:
            return
        logging.info("Sharing %d revocation(s) via ZK", len(self._revs_to_zk))
        while self._revs_to_zk:
            self._zk_write_rev(self._revs_to_zk.popleft())

    def _zk_write(self, data):
        """
        Store packed segment records in the shared ZK path cache.

        The entry is named "<sha256 hex of data>-<SCIONTime.get_time()>".
        A ZkNoConnection error is logged as a warning and otherwise ignored.
        """
        digest = SHA256.new(data).hexdigest()
        entry_name = "%s-%s" % (digest, SCIONTime.get_time())
        try:
            self.path_cache.store(entry_name, data)
        except ZkNoConnection:
            logging.warning("Unable to store segment(s) in shared path: "
                            "no connection to ZK")

    def _zk_write_rev(self, data):
        """
        Store a packed revocation in the shared ZK revocation cache.

        The entry is named "<sha256 hex of data>-<SCIONTime.get_time()>".
        A ZkNoConnection error is logged as a warning and otherwise ignored.
        """
        digest = SHA256.new(data).hexdigest()
        entry_name = "%s-%s" % (digest, SCIONTime.get_time())
        try:
            self.rev_cache.store(entry_name, data)
        except ZkNoConnection:
            logging.warning("Unable to store revocation(s) in shared path: "
                            "no connection to ZK")

    def run(self):
        """
        Run an instance of the Path Server.

        self.worker is started in a daemon thread named "PS.worker" (wrapped
        by thread_safety_net) before delegating to the superclass run().
        """
        worker = threading.Thread(target=thread_safety_net,
                                  args=(self.worker, ),
                                  name="PS.worker",
                                  daemon=True)
        worker.start()
        super().run()
Ejemplo n.º 22
0
class SCIONDaemon(SCIONElement):
    """
    The SCION Daemon used for retrieving and combining paths.

    Path segments received from the local path server are cached in three
    PathSegmentDBs (up, down and core) and combined on demand into
    forwarding paths. With run_local_api, this functionality is also served
    over a local ReliableSocket (the SCIOND API).
    """
    # Max time for a path lookup to succeed/fail.
    PATH_REQ_TOUT = 2
    MAX_REQS = 1024
    # Time a path segment is cached at a host (in seconds).
    SEGMENT_TTL = 300

    def __init__(self,
                 conf_dir,
                 addr,
                 api_addr,
                 run_local_api=False,
                 port=None):
        """
        Initialize an instance of the class SCIONDaemon.

        :param conf_dir: configuration directory, passed to SCIONElement.
        :param addr: local host address, passed on as ``host_addr``.
        :param api_addr: filesystem path of the API socket. If falsy, it
            defaults to "<SCIOND_API_SOCKDIR>/<local ISD-AS>.sock".
        :param bool run_local_api: if True, bind the API socket now and
            register handle_accept for it.
        :param port: local port, passed to SCIONElement.
        """
        super().__init__("sciond", conf_dir, host_addr=addr, port=port)
        # TODO replace by pathstore instance
        self.up_segments = PathSegmentDB(segment_ttl=self.SEGMENT_TTL)
        self.down_segments = PathSegmentDB(segment_ttl=self.SEGMENT_TTL)
        self.core_segments = PathSegmentDB(segment_ttl=self.SEGMENT_TTL)
        self.peer_revs = RevCache()
        # Keep track of requested paths: (dst_ia, flags) -> threading.Event
        # that get_paths callers wait on. Sized by MAX_REQS / PATH_REQ_TOUT.
        self.requested_paths = ExpiringDict(self.MAX_REQS, self.PATH_REQ_TOUT)
        # Held by get_paths and continue_seg_processing while they touch
        # requested_paths.
        self.req_path_lock = threading.Lock()
        self._api_sock = None
        self.daemon_thread = None
        os.makedirs(SCIOND_API_SOCKDIR, exist_ok=True)
        self.api_addr = (api_addr or os.path.join(
            SCIOND_API_SOCKDIR, "%s.sock" % self.addr.isd_as))

        self.CTRL_PLD_CLASS_MAP = {
            PayloadClass.PATH: {
                PMT.REPLY: self.handle_path_reply,
                PMT.REVOCATION: self.handle_revocation,
            },
            PayloadClass.CERT: {
                CertMgmtType.CERT_CHAIN_REQ: self.process_cert_chain_request,
                CertMgmtType.CERT_CHAIN_REPLY: self.process_cert_chain_reply,
                CertMgmtType.TRC_REPLY: self.process_trc_reply,
                CertMgmtType.TRC_REQ: self.process_trc_request,
            },
        }

        self.SCMP_PLD_CLASS_MAP = {
            SCMPClass.PATH: {
                SCMPPathClass.REVOKED_IF: self.handle_scmp_revocation
            },
        }

        if run_local_api:
            self._api_sock = ReliableSocket(bind=(self.api_addr, "sciond"))
            self._socks.add(self._api_sock, self.handle_accept)

    @classmethod
    def start(cls, conf_dir, addr, api_addr=None, run_local_api=False, port=0):
        """
        Initializes, starts, and returns a SCIONDaemon object.

        The instance's run() executes in a daemon thread (wrapped by
        thread_safety_net) that is stored as ``daemon_thread``.

        Example of usage:
        sd = SCIONDaemon.start(conf_dir, addr)
        paths = sd.get_paths(isd_as)
        """
        inst = cls(conf_dir, addr, api_addr, run_local_api, port)
        name = "SCIONDaemon.run %s" % inst.addr.isd_as
        inst.daemon_thread = threading.Thread(target=thread_safety_net,
                                              args=(inst.run, ),
                                              name=name,
                                              daemon=True)
        inst.daemon_thread.start()
        logging.debug("sciond started with api_addr = %s", inst.api_addr)
        return inst

    def _get_msg_meta(self, packet, addr, sock):
        """
        Build message metadata for a packet read from ``sock``.

        Packets from any socket other than the UDP socket are treated as API
        messages and get SockOnlyMetadata; UDP packets use the superclass
        handling.
        """
        if sock != self._udp_sock:
            return packet, SockOnlyMetadata.from_values(sock)  # API socket
        else:
            return super()._get_msg_meta(packet, addr, sock)

    def handle_msg_meta(self, msg, meta):
        """
        Main routine to handle incoming SCION messages.

        Messages with SockOnlyMetadata come from the SCIOND API: they are
        parsed with parse_sciond_msg (parse errors are logged and the message
        dropped) and passed to api_handle_request. Everything else goes to
        the superclass handler.
        """
        if isinstance(meta, SockOnlyMetadata):  # From SCIOND API
            try:
                sciond_msg = parse_sciond_msg(msg)
            except SCIONParseError as err:
                logging.error(str(err))
                return
            self.api_handle_request(sciond_msg, meta)
            return
        super().handle_msg_meta(msg, meta)

    def handle_path_reply(self, path_reply, meta):
        """
        Handle path reply from local path server.

        Revocation infos contained in the reply are added to peer_revs; each
        PCB is wrapped in a PathSegMeta and passed to process_path_seg with
        continue_seg_processing as the continuation.
        """
        for rev_info in path_reply.iter_rev_infos():
            self.peer_revs.add(rev_info)

        for type_, pcb in path_reply.iter_pcbs():
            seg_meta = PathSegMeta(pcb, self.continue_seg_processing, meta,
                                   type_)
            self.process_path_seg(seg_meta)

    def continue_seg_processing(self, seg_meta):
        """
        For every path segment(that can be verified) received from the path
        server this function gets called to continue the processing for the
        segment.
        The segment is added to pathdb and pending requests are checked.

        Pending requests are only re-checked when the type-specific handler
        returns a truthy value, i.e. when a new DB entry was added. Every
        request that can now be resolved has its Event set and is removed.
        """
        pcb = seg_meta.seg
        type_ = seg_meta.type
        map_ = {
            PST.UP: self._handle_up_seg,
            PST.DOWN: self._handle_down_seg,
            PST.CORE: self._handle_core_seg,
        }
        ret = map_[type_](pcb)
        if not ret:
            return
        with self.req_path_lock:
            # .items() makes a copy on an expiring dict, so deleting entries is safe.
            for key, e in self.requested_paths.items():
                if self.path_resolution(*key):
                    e.set()
                    del self.requested_paths[key]

    def _handle_up_seg(self, pcb):
        """
        Store an up segment that ends at the local AS.

        :returns: the segment's first ISD-AS if it was newly added, else None
            (also None when the segment does not end at the local AS).
        """
        if self.addr.isd_as != pcb.last_ia():
            return None
        if self.up_segments.update(pcb) == DBResult.ENTRY_ADDED:
            logging.debug("Up segment added: %s", pcb.short_desc())
            return pcb.first_ia()
        return None

    def _handle_down_seg(self, pcb):
        """
        Store a down segment that does not end at the local AS.

        :returns: the segment's last ISD-AS if it was newly added, else None.
        """
        last_ia = pcb.last_ia()
        if self.addr.isd_as == last_ia:
            return None
        if self.down_segments.update(pcb) == DBResult.ENTRY_ADDED:
            logging.debug("Down segment added: %s", pcb.short_desc())
            return last_ia
        return None

    def _handle_core_seg(self, pcb):
        """
        Store a core segment.

        :returns: the segment's first ISD-AS if it was newly added, else None.
        """
        if self.core_segments.update(pcb) == DBResult.ENTRY_ADDED:
            logging.debug("Core segment added: %s", pcb.short_desc())
            return pcb.first_ia()
        return None

    def api_handle_request(self, msg, meta):
        """
        Handle local API's requests.

        Path requests run in a separate daemon thread because get_paths may
        block for up to PATH_REQ_TOUT seconds; all other request types are
        handled inline. Unsupported types are logged and ignored.
        """
        if msg.MSG_TYPE == SMT.PATH_REQUEST:
            threading.Thread(target=thread_safety_net,
                             args=(self._api_handle_path_request, msg, meta),
                             daemon=True).start()
        elif msg.MSG_TYPE == SMT.REVOCATION:
            self.handle_revocation(msg.rev_info(), meta)
        elif msg.MSG_TYPE == SMT.AS_REQUEST:
            self._api_handle_as_request(msg, meta)
        elif msg.MSG_TYPE == SMT.IF_REQUEST:
            self._api_handle_if_request(msg, meta)
        elif msg.MSG_TYPE == SMT.SERVICE_REQUEST:
            self._api_handle_service_request(msg, meta)
        else:
            logging.warning("API: type %s not supported.",
                            TypeBase.to_str(msg.MSG_TYPE))

    def _api_handle_path_request(self, request, meta):
        """
        Answer an API path request.

        SIBRA requests are rejected with an INTERNAL error. Otherwise the
        paths from get_paths are truncated to maxPaths (when set) and each is
        paired with the address/port of the border router owning its first
        forwarding interface (no address for an empty path).
        """
        req_id = request.id
        if request.p.flags.sibra:
            logging.warning(
                "Requesting SIBRA paths over SCIOND API not supported yet.")
            self._send_path_reply(req_id, [], SCIONDPathReplyError.INTERNAL,
                                  meta)
            return

        dst_ia = request.dst_ia()
        src_ia = request.src_ia()
        if not src_ia:
            src_ia = self.addr.isd_as
        # src_ia is only used to label the handling thread.
        thread = threading.current_thread()
        thread.name = "SCIONDaemon API id:%s %s -> %s" % (thread.ident, src_ia,
                                                          dst_ia)
        paths, error = self.get_paths(dst_ia, flush=request.p.flags.flush)
        if request.p.maxPaths:
            paths = paths[:request.p.maxPaths]
        logging.debug("Replying to api request for %s with %d paths", dst_ia,
                      len(paths))
        reply_entries = []
        for path_meta in paths:
            fwd_if = path_meta.fwd_path().get_fwd_if()
            # Set dummy host addr if path is empty.
            haddr, port = None, None
            if fwd_if:
                br = self.ifid2br[fwd_if]
                haddr, port = br.addr, br.port
            addrs = [haddr] if haddr else []
            first_hop = HostInfo.from_values(addrs, port)
            reply_entry = SCIONDPathReplyEntry.from_values(
                path_meta, first_hop)
            reply_entries.append(reply_entry)
        self._send_path_reply(req_id, reply_entries, error, meta)

    def _send_path_reply(self, req_id, reply_entries, error, meta):
        """Pack and send a SCIONDPathReply for request ``req_id``."""
        path_reply = SCIONDPathReply.from_values(req_id, reply_entries, error)
        self.send_meta(path_reply.pack_full(), meta)

    def _api_handle_as_request(self, request, meta):
        """
        Answer an API AS-info request.

        If the request names an AS, reply whether it is a core AS; otherwise
        reply about the local AS, including the topology MTU.
        """
        remote_as = request.isd_as()
        if remote_as:
            reply_entry = SCIONDASInfoReplyEntry.from_values(
                remote_as, self.is_core_as(remote_as))
        else:
            reply_entry = SCIONDASInfoReplyEntry.from_values(
                self.addr.isd_as, self.is_core_as(), self.topology.mtu)
        as_reply = SCIONDASInfoReply.from_values(request.id, [reply_entry])
        self.send_meta(as_reply.pack_full(), meta)

    def _api_handle_if_request(self, request, meta):
        """
        Answer an API interface-info request with border router host info,
        either for all interfaces or only for the requested interface IDs.
        """
        all_brs = request.all_brs()
        if_list = []
        if not all_brs:
            if_list = list(request.iter_ids())
        if_entries = []
        for if_id, br in self.ifid2br.items():
            if all_brs or if_id in if_list:
                info = HostInfo.from_values([br.addr], br.port)
                reply_entry = SCIONDIFInfoReplyEntry.from_values(if_id, info)
                if_entries.append(reply_entry)
        if_reply = SCIONDIFInfoReply.from_values(request.id, if_entries)
        self.send_meta(if_reply.pack_full(), meta)

    def _api_handle_service_request(self, request, meta):
        """
        Answer an API service-info request with the hosts returned by
        dns_query_topo, for all service types or only the requested ones.
        """
        all_svcs = request.all_services()
        svc_list = []
        if not all_svcs:
            svc_list = list(request.iter_service_types())
        svc_entries = []
        for svc_type in ServiceType.all():
            if all_svcs or svc_type in svc_list:
                lookup_res = self.dns_query_topo(svc_type)
                host_infos = []
                for addr, port in lookup_res:
                    host_infos.append(HostInfo.from_values([addr], port))
                reply_entry = SCIONDServiceInfoReplyEntry.from_values(
                    svc_type, host_infos)
                svc_entries.append(reply_entry)
        svc_reply = SCIONDServiceInfoReply.from_values(request.id, svc_entries)
        self.send_meta(svc_reply.pack_full(), meta)

    def handle_scmp_revocation(self, pld, meta):
        """Parse the revocation carried in an SCMP payload and handle it."""
        rev_info = RevocationInfo.from_raw(pld.info.rev_info)
        self.handle_revocation(rev_info, meta)

    def handle_revocation(self, rev_info, meta):
        """
        Drop all cached segments affected by a valid revocation.

        Revocations rejected by _validate_revocation are ignored.
        """
        assert isinstance(rev_info, RevocationInfo)
        if not self._validate_revocation(rev_info):
            return
        # Go through all segment databases and remove affected segments.
        removed_up = self._remove_revoked_pcbs(self.up_segments, rev_info)
        removed_core = self._remove_revoked_pcbs(self.core_segments, rev_info)
        removed_down = self._remove_revoked_pcbs(self.down_segments, rev_info)
        logging.info("Removed %d UP- %d CORE- and %d DOWN-Segments.",
                     removed_up, removed_core, removed_down)

    def _remove_revoked_pcbs(self, db, rev_info):
        """
        Removes all segments from 'db' that are affected by 'rev_info'.

        Nothing is removed if the revocation's epoch fails
        ConnectedHashTree.verify_epoch. Otherwise every segment that has at
        least one AS entry for which _verify_revocation_for_asm holds is
        deleted.

        :param db: The PathSegmentDB.
        :type db: :class:`lib.path_db.PathSegmentDB`
        :param rev_info: The revocation info
        :type rev_info: RevocationInfo

        :returns: The number of deletions.
        :rtype: int
        """

        if not ConnectedHashTree.verify_epoch(rev_info.p.epoch):
            logging.debug(
                "Failed to verify epoch: rev_info epoch %d,current epoch %d.",
                rev_info.p.epoch, ConnectedHashTree.get_current_epoch())
            return 0

        to_remove = []
        for segment in db(full=True):
            for asm in segment.iter_asms():
                if self._verify_revocation_for_asm(rev_info, asm):
                    logging.debug("Removing segment: %s",
                                  segment.short_desc())
                    to_remove.append(segment.get_hops_hash())
                    # One affected AS entry is enough; checking the rest
                    # would only queue the same hash again.
                    break
        return db.delete_all(to_remove)

    def _flush_path_dbs(self):
        """Flush the core, down and up segment databases."""
        self.core_segments.flush()
        self.down_segments.flush()
        self.up_segments.flush()

    def get_paths(self, dst_ia, flags=(), flush=False):
        """Return a list of paths.

        :param dst_ia: destination ISD-AS.
        :param flags: path flags; PATH_FLAG_SIBRA selects SIBRA resolution.
        :param bool flush: if True, flush all segment databases first.
        :returns: tuple (paths, SCIONDPathReplyError code). If no path can be
            built from cached segments, a segment request is sent (at most
            one outstanding per (dst_ia, flags)) and the call blocks for up
            to PATH_REQ_TOUT seconds; on timeout ([], PS_TIMEOUT) is
            returned.
        """
        logging.debug("Paths requested for ISDAS=%s, flags=%s, flush=%s",
                      dst_ia, flags, flush)
        if flush:
            logging.info("Flushing PathDBs.")
            self._flush_path_dbs()
        if self.addr.isd_as == dst_ia or (self.addr.isd_as.any_as() == dst_ia
                                          and self.topology.is_core_as):
            # Either the destination is the local AS, or the destination is any
            # core AS in this ISD, and the local AS is in the core
            empty = SCIONPath()
            empty_meta = FwdPathMeta.from_values(empty, [], self.topology.mtu)
            return [empty_meta], SCIONDPathReplyError.OK
        paths = self.path_resolution(dst_ia, flags=flags)
        if not paths:
            key = dst_ia, flags
            with self.req_path_lock:
                if key not in self.requested_paths:
                    # No previous outstanding request
                    self.requested_paths[key] = threading.Event()
                    self._fetch_segments(key)
                e = self.requested_paths[key]
            if not e.wait(self.PATH_REQ_TOUT):
                logging.error("Query timed out for %s", dst_ia)
                return [], SCIONDPathReplyError.PS_TIMEOUT
            paths = self.path_resolution(dst_ia, flags=flags)
        error_code = (SCIONDPathReplyError.OK
                      if paths else SCIONDPathReplyError.NO_PATHS)
        return paths, error_code

    def path_resolution(self, dst_ia, flags=()):
        """
        Combine cached segments into paths towards ``dst_ia``.

        The resolver is chosen by whether the local AS and the destination
        are core ASes. For SIBRA (PATH_FLAG_SIBRA in flags) the segment
        tuples are stripped of None entries and converted by
        _sibra_strip_pcbs instead of being turned into full paths.
        """
        # dst as == 0 means any core AS in the specified ISD.
        dst_is_core = self.is_core_as(dst_ia) or dst_ia[1] == 0
        sibra = PATH_FLAG_SIBRA in flags
        if self.topology.is_core_as:
            if dst_is_core:
                ret = self._resolve_core_core(dst_ia, sibra=sibra)
            else:
                ret = self._resolve_core_not_core(dst_ia, sibra=sibra)
        elif dst_is_core:
            ret = self._resolve_not_core_core(dst_ia, sibra=sibra)
        elif sibra:
            ret = self._resolve_not_core_not_core_sibra(dst_ia)
        else:
            ret = self._resolve_not_core_not_core_scion(dst_ia)
        if not sibra:
            return ret
        # FIXME(kormat): Strip off PCBs, and just return sibra reservation
        # blocks
        return self._sibra_strip_pcbs(self._strip_nones(ret))

    def _resolve_core_core(self, dst_ia, sibra=False):
        """
        Resolve path from core to core.

        :returns: set of (None, core_seg, None) tuples if sibra, otherwise
            the result of tuples_to_full_paths on them.
        """
        res = set()
        for cseg in self.core_segments(last_ia=self.addr.isd_as,
                                       sibra=sibra,
                                       **dst_ia.params()):
            res.add((None, cseg, None))
        if sibra:
            return res
        return tuples_to_full_paths(res)

    def _resolve_core_not_core(self, dst_ia, sibra=False):
        """
        Resolve path from core to non-core.

        :returns: set of (None, core_seg|None, down_seg) tuples if sibra,
            otherwise the result of tuples_to_full_paths on them.
        """
        res = set()
        # First check whether there is a direct path.
        for dseg in self.down_segments(first_ia=self.addr.isd_as,
                                       last_ia=dst_ia,
                                       sibra=sibra):
            res.add((None, None, dseg))
        # Check core-down combination.
        for dseg in self.down_segments(last_ia=dst_ia, sibra=sibra):
            dseg_ia = dseg.first_ia()
            if self.addr.isd_as == dseg_ia:
                # Down segment starts at the local AS: it is a direct path
                # and was already collected above.
                continue
            for cseg in self.core_segments(first_ia=dseg_ia,
                                           last_ia=self.addr.isd_as,
                                           sibra=sibra):
                res.add((None, cseg, dseg))
        if sibra:
            return res
        return tuples_to_full_paths(res)

    def _resolve_not_core_core(self, dst_ia, sibra=False):
        """
        Resolve path from non-core to core.

        :returns: set of (up_seg, core_seg|None, None) tuples if sibra,
            otherwise the result of tuples_to_full_paths on them.
        """
        res = set()
        params = dst_ia.params()
        params["sibra"] = sibra
        if dst_ia[0] == self.addr.isd_as[0]:
            # Dst in local ISD. First check whether DST is a (super)-parent.
            for useg in self.up_segments(**params):
                res.add((useg, None, None))
        # Check whether dst is known core AS.
        for cseg in self.core_segments(**params):
            # Check do we have an up-seg that is connected to core_seg.
            for useg in self.up_segments(first_ia=cseg.last_ia(), sibra=sibra):
                res.add((useg, cseg, None))
        if sibra:
            return res
        return tuples_to_full_paths(res)

    def _resolve_not_core_not_core_scion(self, dst_ia):
        """
        Resolve SCION path from non-core to non-core.

        Shortcut paths from build_shortcut_paths come first, followed by
        up-down and up-core-down combinations.
        """
        up_segs = self.up_segments()
        down_segs = self.down_segments(last_ia=dst_ia)
        core_segs = self._calc_core_segs(dst_ia[0], up_segs, down_segs)
        full_paths = build_shortcut_paths(up_segs, down_segs, self.peer_revs)
        tuples = []
        for up_seg in up_segs:
            for down_seg in down_segs:
                tuples.append((up_seg, None, down_seg))
                for core_seg in core_segs:
                    tuples.append((up_seg, core_seg, down_seg))
        full_paths.extend(tuples_to_full_paths(tuples))
        return full_paths

    def _resolve_not_core_not_core_sibra(self, dst_ia):
        """
        Resolve SIBRA path from non-core to non-core.

        :returns: set of (up_seg, core_seg|None, down_seg) tuples.
        """
        res = set()
        up_segs = set(self.up_segments(sibra=True))
        down_segs = set(self.down_segments(last_ia=dst_ia, sibra=True))
        for up_seg, down_seg in product(up_segs, down_segs):
            src_core_ia = up_seg.first_ia()
            dst_core_ia = down_seg.first_ia()
            if src_core_ia == dst_core_ia:
                # Same shape as the other resolvers; path_resolution strips
                # the None before use.
                res.add((up_seg, None, down_seg))
                continue
            for core_seg in self.core_segments(first_ia=dst_core_ia,
                                               last_ia=src_core_ia,
                                               sibra=True):
                res.add((up_seg, core_seg, down_seg))
        return res

    def _strip_nones(self, set_):
        """
        Strip None entries from a set of tuples.

        Note that filter(None, ...) drops every falsy element, not only None.
        """
        return [tuple(filter(None, tup)) for tup in set_]

    def _sibra_strip_pcbs(self, paths):
        """Convert each tuple of SIBRA PCBs via _sibra_strip_pcb."""
        return [[self._sibra_strip_pcb(pcb) for pcb in pcbs]
                for pcbs in paths]

    def _sibra_strip_pcb(self, pcb):
        """
        Convert a SIBRA PCB into (reservation id, ResvBlockSteady, iflist).

        The SOFs of the PCB's SIBRA extension are zipped with its AS entries
        (in reverse order when pcb_ext.p.up is set) to fill the reservation
        block and collect the traversed (isd_as, ifid) pairs.
        """
        assert pcb.is_sibra()
        pcb_ext = pcb.sibra_ext
        resv_info = pcb_ext.info
        resv = ResvBlockSteady.from_values(resv_info, pcb.get_n_hops())
        asms = pcb.iter_asms()
        if pcb_ext.p.up:
            asms = reversed(list(asms))
        iflist = []
        for sof, asm in zip(pcb_ext.iter_sofs(), asms):
            resv.sofs.append(sof)
            iflist.extend(
                self._sibra_add_ifs(asm.isd_as(), sof, resv_info.fwd_dir))
        assert resv.num_hops == len(resv.sofs)
        return pcb_ext.p.id, resv, iflist

    def _sibra_add_ifs(self, isd_as, sof, fwd):
        """
        Return (isd_as, ifid) pairs for the non-zero interfaces of ``sof``.

        Ingress comes before egress when ``fwd`` is truthy, egress first
        otherwise.
        """
        ret = []

        def _add(ifid):
            if ifid:
                ret.append((isd_as, ifid))

        if fwd:
            _add(sof.ingress)
            _add(sof.egress)
        else:
            _add(sof.egress)
            _add(sof.ingress)
        return ret

    def _wait_for_events(self, events, deadline):
        """
        Wait on a set of events, but only until the specified deadline. Returns
        the number of events that happened while waiting.
        """
        count = 0
        for e in events:
            if e.wait(max(0, deadline - SCIONTime.get_time())):
                count += 1
        return count

    def _fetch_segments(self, key):
        """
        Called to fetch the requested path.

        Sends a PathSegmentReq for ``key`` = (dst_ia, flags) to the first
        path service address returned by dns_query_topo. If the lookup fails
        the error is logged and no request is sent.
        """
        dst_ia, flags = key
        try:
            addr, port = self.dns_query_topo(PATH_SERVICE)[0]
        except SCIONServiceLookupError:
            log_exception("Error querying path service:")
            return
        req = PathSegmentReq.from_values(self.addr.isd_as, dst_ia, flags=flags)
        logging.debug("Sending path request: %s", req.short_desc())
        meta = self.DefaultMeta.from_values(host=addr, port=port)
        self.send_meta(req, meta)

    def _calc_core_segs(self, dst_isd, up_segs, down_segs):
        """
        Calculate all possible core segments joining the provided up and down
        segments.

        Source core ASes are the first ASes of the up segments (local ISD);
        destination core ASes are the first ASes of the down segments (in
        ``dst_isd``). Returns a list of the cached core segments connecting
        any such pair; pairs without a cached segment contribute nothing.
        """
        src_core_ases = set()
        dst_core_ases = set()
        for seg in up_segs:
            src_core_ases.add(seg.first_ia()[1])
        for seg in down_segs:
            dst_core_ases.add(seg.first_ia()[1])
        # Generate all possible AS pairs
        as_pairs = list(product(src_core_ases, dst_core_ases))
        return self._find_core_segs(self.addr.isd_as[0], dst_isd, as_pairs)

    def _find_core_segs(self, src_isd, dst_isd, as_pairs):
        """
        Given a set of AS pairs across 2 ISDs, return the core segments
        connecting those pairs.

        Pairs that resolve to the same ISD-AS are skipped.
        """
        core_segs = []
        for src_core_as, dst_core_as in as_pairs:
            src_ia = ISD_AS.from_values(src_isd, src_core_as)
            dst_ia = ISD_AS.from_values(dst_isd, dst_core_as)
            if src_ia == dst_ia:
                continue
            segs = self.core_segments(first_ia=dst_ia, last_ia=src_ia)
            if segs:
                core_segs.extend(segs)
        return core_segs
Ejemplo n.º 23
0
class SCIONDaemon(SCIONElement):
    """
    The SCION Daemon used for retrieving and combining paths.

    Path segments received from the path service are cached in three
    PathSegmentDBs (up, down, core) and combined into end-to-end paths on
    request. Messages arriving on the local API socket are dispatched by
    api_handle_request.
    """
    # Capacity of the `requested_paths` ExpiringDict (outstanding requests).
    MAX_REQS = 1024
    # Time a path segment is cached at a host (in seconds).
    SEGMENT_TTL = 300
    # Empty Path TTL
    EMPTY_PATH_TTL = SEGMENT_TTL

    def __init__(self, conf_dir, addr, api_addr, run_local_api=False,
                 port=None, spki_cache_dir=GEN_CACHE_PATH, prom_export=None, delete_sock=False):
        """
        Initialize an instance of the class SCIONDaemon.

        :param str conf_dir: configuration directory.
        :param addr: public address; passed with `port` to the base class as
            `public=[(addr, port)]`.
        :param api_addr: path of the local API socket; when falsy the value of
            get_default_sciond_path() is used.
        :param bool run_local_api: bind the API socket and accept connections.
        :param port: public port.
        :param spki_cache_dir: cache directory passed to the base class.
        :param prom_export: prometheus export address.
        :param bool delete_sock: remove an existing file at the API socket
            path before binding.
        """
        super().__init__("sciond", conf_dir, spki_cache_dir=spki_cache_dir,
                         prom_export=prom_export, public=[(addr, port)])
        up_labels = {**self._labels, "type": "up"} if self._labels else None
        down_labels = {**self._labels, "type": "down"} if self._labels else None
        core_labels = {**self._labels, "type": "core"} if self._labels else None
        self.up_segments = PathSegmentDB(segment_ttl=self.SEGMENT_TTL, labels=up_labels)
        self.down_segments = PathSegmentDB(segment_ttl=self.SEGMENT_TTL, labels=down_labels)
        self.core_segments = PathSegmentDB(segment_ttl=self.SEGMENT_TTL, labels=core_labels)
        self.rev_cache = RevCache()
        # Keep track of requested paths.
        self.requested_paths = ExpiringDict(self.MAX_REQS, PATH_REQ_TOUT)
        self.req_path_lock = threading.Lock()
        self._api_sock = None
        self.daemon_thread = None
        os.makedirs(SCIOND_API_SOCKDIR, exist_ok=True)
        self.api_addr = (api_addr or get_default_sciond_path())
        if delete_sock:
            try:
                os.remove(self.api_addr)
            except OSError as e:
                # A missing socket file is fine; anything else is reported.
                if e.errno != errno.ENOENT:
                    logging.error("Could not delete socket %s: %s", self.api_addr, e)

        self.CTRL_PLD_CLASS_MAP = {
            PayloadClass.PATH: {
                PMT.REPLY: self.handle_path_reply,
                PMT.REVOCATION: self.handle_revocation,
            },
            PayloadClass.CERT: {
                CertMgmtType.CERT_CHAIN_REQ: self.process_cert_chain_request,
                CertMgmtType.CERT_CHAIN_REPLY: self.process_cert_chain_reply,
                CertMgmtType.TRC_REPLY: self.process_trc_reply,
                CertMgmtType.TRC_REQ: self.process_trc_request,
            },
        }

        self.SCMP_PLD_CLASS_MAP = {
            SCMPClass.PATH:
                {SCMPPathClass.REVOKED_IF: self.handle_scmp_revocation},
        }

        if run_local_api:
            self._api_sock = ReliableSocket(bind_unix=(self.api_addr, "sciond"))
            self._socks.add(self._api_sock, self.handle_accept)

    @classmethod
    def start(cls, conf_dir, addr, api_addr=None, run_local_api=False, port=0):
        """
        Initializes and starts a SCIOND instance.

        The instance's run() method is executed in a daemon thread, which is
        stored in `inst.daemon_thread`.

        :returns: the started instance.
        """
        inst = cls(conf_dir, addr, api_addr, run_local_api, port)
        name = "SCIONDaemon.run %s" % inst.addr.isd_as
        inst.daemon_thread = threading.Thread(
            target=thread_safety_net, args=(inst.run,), name=name, daemon=True)
        inst.daemon_thread.start()
        logging.debug("sciond started with api_addr = %s", inst.api_addr)
        return inst

    def _get_msg_meta(self, packet, addr, sock):
        """
        Attach metadata to an incoming packet.

        Packets from any socket other than the UDP socket are treated as API
        traffic and tagged with SockOnlyMetadata.
        """
        if sock != self._udp_sock:
            return packet, SockOnlyMetadata.from_values(sock)  # API socket
        else:
            return super()._get_msg_meta(packet, addr, sock)

    def handle_msg_meta(self, msg, meta):
        """
        Main routine to handle incoming SCION messages.

        API messages are parsed as SCIONDMsg (parse errors are logged and the
        message dropped); everything else goes to the base class handler.
        """
        if isinstance(meta, SockOnlyMetadata):  # From SCIOND API
            try:
                sciond_msg = SCIONDMsg.from_raw(msg)
            except SCIONParseError as err:
                logging.error(str(err))
                return
            self.api_handle_request(sciond_msg, meta)
            return
        super().handle_msg_meta(msg, meta)

    def handle_path_reply(self, cpld, meta):
        """
        Handle path reply from local path server.

        Revocations contained in the reply are checked and, when valid, cached
        and applied. The matching outstanding request (if any) is notified,
        and every contained PCB is passed to _process_path_seg with
        continue_seg_processing as the continuation.
        """
        pmgt = cpld.union
        path_reply = pmgt.union
        assert isinstance(path_reply, PathSegmentReply), type(path_reply)
        recs = path_reply.recs()
        for srev_info in recs.iter_srev_infos():
            # Bind srev_info as a default argument: a plain closure reads the
            # loop variable when the callback runs, which would be a later
            # element if check_revocation invokes it after the loop advanced.
            self.check_revocation(
                srev_info,
                lambda x, srev_info=srev_info: self.continue_revocation_processing(
                    srev_info) if not x else False,
                meta)

        req = path_reply.req()
        key = req.dst_ia(), req.flags()
        with self.req_path_lock:
            r = self.requested_paths.get(key)
            if r:
                r.notify_reply(path_reply)
            else:
                logging.warning("No outstanding request found for %s", key)
        for type_, pcb in recs.iter_pcbs():
            seg_meta = PathSegMeta(pcb, self.continue_seg_processing,
                                   meta, type_, params=(r,))
            self._process_path_seg(seg_meta, cpld.req_id)

    def continue_revocation_processing(self, srev_info):
        """Cache a verified revocation and drop the segments it affects."""
        self.rev_cache.add(srev_info)
        self.remove_revoked_segments(srev_info.rev_info())

    def continue_seg_processing(self, seg_meta):
        """
        For every path segment(that can be verified) received from the path
        server this function gets called to continue the processing for the
        segment.
        The segment is added to pathdb and pending requests are checked.
        """
        pcb = seg_meta.seg
        type_ = seg_meta.type
        # Check that segment does not contain a revoked interface.
        if not self.check_revoked_interface(pcb, self.rev_cache):
            return
        map_ = {
            PST.UP: self._handle_up_seg,
            PST.DOWN: self._handle_down_seg,
            PST.CORE: self._handle_core_seg,
        }
        map_[type_](pcb)
        r = seg_meta.params[0]
        if r:
            r.verified_segment()

    def _handle_up_seg(self, pcb):
        """Store an up segment ending at the local AS; return its first IA if newly added."""
        if self.addr.isd_as != pcb.last_ia():
            return None
        if self.up_segments.update(pcb) == DBResult.ENTRY_ADDED:
            logging.debug("Up segment added: %s", pcb.short_desc())
            return pcb.first_ia()
        return None

    def _handle_down_seg(self, pcb):
        """Store a down segment not ending at the local AS; return its last IA if newly added."""
        last_ia = pcb.last_ia()
        if self.addr.isd_as == last_ia:
            return None
        if self.down_segments.update(pcb) == DBResult.ENTRY_ADDED:
            logging.debug("Down segment added: %s", pcb.short_desc())
            return last_ia
        return None

    def _handle_core_seg(self, pcb):
        """Store a core segment; return its first IA if newly added."""
        if self.core_segments.update(pcb) == DBResult.ENTRY_ADDED:
            logging.debug("Core segment added: %s", pcb.short_desc())
            return pcb.first_ia()
        return None

    def api_handle_request(self, msg, meta):
        """
        Handle local API's requests.

        Path requests are served in a separate daemon thread (they may block
        waiting for the path service); all other types are handled inline.
        """
        mtype = msg.type()
        if mtype == SMT.PATH_REQUEST:
            threading.Thread(
                target=thread_safety_net,
                args=(self._api_handle_path_request, msg, meta),
                daemon=True).start()
        elif mtype == SMT.REVOCATION:
            self._api_handle_rev_notification(msg, meta)
        elif mtype == SMT.AS_REQUEST:
            self._api_handle_as_request(msg, meta)
        elif mtype == SMT.IF_REQUEST:
            self._api_handle_if_request(msg, meta)
        elif mtype == SMT.SERVICE_REQUEST:
            self._api_handle_service_request(msg, meta)
        elif mtype == SMT.SEGTYPEHOP_REQUEST:
            self._api_handle_seg_type_request(msg, meta)
        else:
            logging.warning(
                "API: type %s not supported.", TypeBase.to_str(mtype))

    def _api_handle_path_request(self, pld, meta):
        """
        Answer a SCIONDPathRequest with the paths from get_paths.

        The result is truncated to `maxPaths` when that is set. For each path
        the first hop is the internal address of the border router owning the
        path's forwarding interface (no address for an empty path).
        """
        request = pld.union
        assert isinstance(request, SCIONDPathRequest), type(request)
        req_id = pld.id

        dst_ia = request.dst_ia()
        src_ia = request.src_ia()
        if not src_ia:
            src_ia = self.addr.isd_as
        thread = threading.current_thread()
        thread.name = "SCIONDaemon API id:%s %s -> %s" % (
            thread.ident, src_ia, dst_ia)
        paths, error = self.get_paths(dst_ia, flush=request.p.flags.refresh)
        if request.p.maxPaths:
            paths = paths[:request.p.maxPaths]

        reply_entries = []
        for path_meta in paths:
            fwd_if = path_meta.fwd_path().get_fwd_if()
            # Set dummy host addr if path is empty.
            haddr, port = None, None
            if fwd_if:
                br = self.ifid2br[fwd_if]
                haddr, port = br.int_addrs.public[0]
            addrs = [haddr] if haddr else []
            first_hop = HostInfo.from_values(addrs, port)
            reply_entry = SCIONDPathReplyEntry.from_values(
                path_meta, first_hop)
            reply_entries.append(reply_entry)
        logging.debug("Replying to api request for %s with %d paths:\n%s",
                      dst_ia, len(paths), "\n".join([p.short_desc() for p in paths]))
        self._send_path_reply(req_id, reply_entries, error, meta)

    def _send_path_reply(self, req_id, reply_entries, error, meta):
        """Pack and send a SCIONDPathReply for request `req_id`."""
        path_reply = SCIONDMsg(SCIONDPathReply.from_values(reply_entries, error), req_id)
        self.send_meta(path_reply.pack(), meta)

    def _api_handle_as_request(self, pld, meta):
        """Answer an AS info request for the local AS (default) or a remote AS."""
        request = pld.union
        assert isinstance(request, SCIONDASInfoRequest), type(request)
        req_ia = request.isd_as()
        if not req_ia or req_ia.is_zero() or req_ia == self.addr.isd_as:
            # Request is for the local AS.
            reply_entry = SCIONDASInfoReplyEntry.from_values(
                self.addr.isd_as, self.is_core_as(), self.topology.mtu)
        else:
            # Request is for a remote AS.
            reply_entry = SCIONDASInfoReplyEntry.from_values(req_ia, self.is_core_as(req_ia))
        as_reply = SCIONDMsg(SCIONDASInfoReply.from_values([reply_entry]), pld.id)
        self.send_meta(as_reply.pack(), meta)

    def _api_handle_if_request(self, pld, meta):
        """Answer with the internal address of all, or the requested, interfaces' BRs."""
        request = pld.union
        assert isinstance(request, SCIONDIFInfoRequest), type(request)
        all_brs = request.all_brs()
        if_list = []
        if not all_brs:
            if_list = list(request.iter_ids())
        if_entries = []
        for if_id, br in self.ifid2br.items():
            if all_brs or if_id in if_list:
                br_addr, br_port = br.int_addrs.public[0]
                info = HostInfo.from_values([br_addr], br_port)
                reply_entry = SCIONDIFInfoReplyEntry.from_values(if_id, info)
                if_entries.append(reply_entry)
        if_reply = SCIONDMsg(SCIONDIFInfoReply.from_values(if_entries), pld.id)
        self.send_meta(if_reply.pack(), meta)

    def _api_handle_service_request(self, pld, meta):
        """Answer with the addresses of all, or the requested, service types."""
        request = pld.union
        assert isinstance(request, SCIONDServiceInfoRequest), type(request)
        all_svcs = request.all_services()
        svc_list = []
        if not all_svcs:
            svc_list = list(request.iter_service_types())
        svc_entries = []
        for svc_type in ServiceType.all():
            if all_svcs or svc_type in svc_list:
                lookup_res = self.dns_query_topo(svc_type)
                host_infos = []
                for addr, port in lookup_res:
                    host_infos.append(HostInfo.from_values([addr], port))
                reply_entry = SCIONDServiceInfoReplyEntry.from_values(
                    svc_type, host_infos)
                svc_entries.append(reply_entry)
        svc_reply = SCIONDMsg(SCIONDServiceInfoReply.from_values(svc_entries), pld.id)
        self.send_meta(svc_reply.pack(), meta)

    def _api_handle_rev_notification(self, pld, meta):
        """Process a revocation submitted through the API; a reply is sent back."""
        request = pld.union
        assert isinstance(request, SCIONDRevNotification), type(request)
        self.handle_revocation(CtrlPayload(PathMgmt(request.srev_info())), meta, pld)

    def _api_handle_seg_type_request(self, pld, meta):
        """
        Reply with the interfaces, timestamp and expiration time of every
        cached segment of the requested type (core, up or down).

        An unrecognized type is logged and answered with an empty reply.
        """
        request = pld.union
        assert isinstance(request, SCIONDSegTypeHopRequest), type(request)
        segment_type = request.p.type
        db = None
        if segment_type == PST.CORE:
            db = self.core_segments
        elif segment_type == PST.UP:
            db = self.up_segments
        elif segment_type == PST.DOWN:
            db = self.down_segments
        else:
            logging.error("Requesting segment type %s unrecognized.", segment_type)

        seg_entries = []
        # Without a matching DB there is nothing to report, but a reply is
        # still sent so the client gets an answer.
        segments = db(full=True) if db is not None else []
        for segment in segments:
            if_list = []
            for asm in segment.iter_asms():
                isd_as = asm.isd_as()
                hof = asm.pcbm(0).hof()
                egress = hof.egress_if
                ingress = hof.ingress_if
                if ingress:
                    if_list.append(PathInterface.from_values(isd_as, ingress))
                if egress:
                    if_list.append(PathInterface.from_values(isd_as, egress))
            reply_entry = SCIONDSegTypeHopReplyEntry.from_values(
                if_list, segment.get_timestamp(), segment.get_expiration_time())
            seg_entries.append(reply_entry)
        seg_reply = SCIONDMsg(
            SCIONDSegTypeHopReply.from_values(seg_entries), pld.id)
        self.send_meta(seg_reply.pack(), meta)

    def handle_scmp_revocation(self, pld, meta):
        """Extract the signed revocation from an SCMP message and process it."""
        srev_info = SignedRevInfo.from_raw(pld.info.srev_info)
        self.handle_revocation(CtrlPayload(PathMgmt(srev_info)), meta)

    def handle_revocation(self, cpld, meta, pld=None):
        """
        Check a signed revocation and continue in process_revocation.

        :param pld: the originating API message, if any; when given, a
            SCIONDRevReply is sent back.
        """
        pmgt = cpld.union
        srev_info = pmgt.union
        rev_info = srev_info.rev_info()
        assert isinstance(rev_info, RevocationInfo), type(rev_info)
        logging.debug("Received revocation: %s from %s", srev_info.short_desc(), meta)
        self.check_revocation(srev_info,
                              lambda e: self.process_revocation(e, srev_info, meta, pld), meta)

    def process_revocation(self, error, srev_info, meta, pld):
        """
        Apply a checked revocation and, for API requests, report the outcome.

        :param error: None if the revocation is valid, else the error raised
            by the check. Error types are matched exactly; any type not listed
            below is reported as UNKNOWN.
        """
        rev_info = srev_info.rev_info()
        if error is None:
            status = SCIONDRevReplyStatus.VALID
            self.rev_cache.add(srev_info)
            self.remove_revoked_segments(rev_info)
        elif type(error) is RevInfoValidationError:
            logging.error("Failed to validate RevInfo %s from %s: %s",
                          srev_info.short_desc(), meta, error)
            status = SCIONDRevReplyStatus.INVALID
        elif type(error) is RevInfoExpiredError:
            logging.info("Ignoring expired Revinfo, %s from %s", srev_info.short_desc(), meta)
            status = SCIONDRevReplyStatus.STALE
        elif type(error) is SignedRevInfoCertFetchError:
            logging.error("Failed to fetch certificate for SignedRevInfo %s from %s: %s",
                          srev_info.short_desc(), meta, error)
            status = SCIONDRevReplyStatus.UNKNOWN
        elif type(error) is SignedRevInfoVerificationError:
            logging.error("Failed to verify SRevInfo %s from %s: %s",
                          srev_info.short_desc(), meta, error)
            status = SCIONDRevReplyStatus.SIGFAIL
        else:
            # SCIONBaseError and any other error type: never leave the status
            # unset, otherwise the reply would be built from None.
            logging.error("Revocation check failed for %s from %s:\n%s",
                          srev_info.short_desc(), meta, error)
            status = SCIONDRevReplyStatus.UNKNOWN

        if pld:
            rev_reply = SCIONDMsg(SCIONDRevReply.from_values(status), pld.id)
            self.send_meta(rev_reply.pack(), meta)

    def remove_revoked_segments(self, rev_info):
        """
        Remove the cached segments affected by `rev_info`.

        Core links affect core segments; parent/child links affect up and
        down segments; peer links affect none of the DBs here.
        """
        # Go through all segment databases and remove affected segments.
        removed_up = removed_core = removed_down = 0
        if rev_info.p.linkType == LinkType.CORE:
            removed_core = self._remove_revoked_pcbs(self.core_segments, rev_info)
        elif rev_info.p.linkType in [LinkType.PARENT, LinkType.CHILD]:
            removed_up = self._remove_revoked_pcbs(self.up_segments, rev_info)
            removed_down = self._remove_revoked_pcbs(self.down_segments, rev_info)
        elif rev_info.p.linkType != LinkType.PEER:
            logging.error("Bad RevInfo link type: %s", rev_info.p.linkType)

        logging.info("Removed %d UP- %d CORE- and %d DOWN-Segments.",
                     removed_up, removed_core, removed_down)

    def _remove_revoked_pcbs(self, db, rev_info):
        """
        Removes all segments from 'db' that have a revoked upstream PCBMarking.

        :param db: The PathSegmentDB.
        :type db: :class:`lib.path_db.PathSegmentDB`
        :param rev_info: The revocation info
        :type rev_info: RevocationInfo

        :returns: The number of deletions.
        :rtype: int
        """
        to_remove = []
        for segment in db(full=True):
            for asm in segment.iter_asms():
                if self._check_revocation_for_asm(rev_info, asm, verify_all=False):
                    logging.debug("Removing segment: %s", segment.short_desc())
                    to_remove.append(segment.get_hops_hash())
                    # One revoked ASM is enough; don't queue the segment twice.
                    break
        return db.delete_all(to_remove)

    def _flush_path_dbs(self):
        """Drop all cached core, down and up segments."""
        self.core_segments.flush()
        self.down_segments.flush()
        self.up_segments.flush()

    def get_paths(self, dst_ia, flags=(), flush=False):
        """
        Return the paths to `dst_ia`.

        If no cached segments produce a path, a request is sent to the path
        service (or an outstanding one for the same key is reused) and the
        call blocks for up to PATH_REQ_TOUT seconds.

        :param dst_ia: destination ISD-AS.
        :param flags: path request flags.
        :param bool flush: flush all segment DBs first.
        :returns: tuple of (list of paths, SCIONDPathReplyError code).
        """
        logging.debug("Paths requested for ISDAS=%s, flags=%s, flush=%s",
                      dst_ia, flags, flush)
        if flush:
            logging.info("Flushing PathDBs.")
            self._flush_path_dbs()
        if self.addr.isd_as == dst_ia or (
                self.addr.isd_as.any_as() == dst_ia and
                self.topology.is_core_as):
            # Either the destination is the local AS, or the destination is any
            # core AS in this ISD, and the local AS is in the core
            empty = SCIONPath()
            exp_time = int(time.time()) + self.EMPTY_PATH_TTL
            empty_meta = FwdPathMeta.from_values(empty, [], self.topology.mtu, exp_time)
            return [empty_meta], SCIONDPathReplyError.OK
        paths = self.path_resolution(dst_ia, flags=flags)
        if not paths:
            key = dst_ia, flags
            with self.req_path_lock:
                r = self.requested_paths.get(key)
                if r is None:
                    # No previous outstanding request
                    req = PathSegmentReq.from_values(self.addr.isd_as, dst_ia, flags=flags)
                    r = RequestState(req.copy())
                    self.requested_paths[key] = r
                    self._fetch_segments(req)
            # Wait until event gets set.
            timeout = not r.e.wait(PATH_REQ_TOUT)
            with self.req_path_lock:
                if timeout:
                    r.done()
                if key in self.requested_paths:
                    del self.requested_paths[key]
            if timeout:
                logging.error("Query timed out for %s", dst_ia)
                return [], SCIONDPathReplyError.PS_TIMEOUT
            # Check if we can fulfill the path request.
            paths = self.path_resolution(dst_ia, flags=flags)
            if not paths:
                logging.error("No paths found for %s", dst_ia)
                return [], SCIONDPathReplyError.NO_PATHS
        return paths, SCIONDPathReplyError.OK

    def path_resolution(self, dst_ia, flags=()):
        """
        Combine cached segments into paths to `dst_ia`.

        The resolver is chosen by whether the local AS and the destination are
        core ASes. With PATH_FLAG_SIBRA in `flags`, SIBRA reservation data is
        returned instead of full paths.
        """
        # dst as == 0 means any core AS in the specified ISD.
        dst_is_core = self.is_core_as(dst_ia) or dst_ia[1] == 0
        sibra = PATH_FLAG_SIBRA in flags
        if self.topology.is_core_as:
            if dst_is_core:
                ret = self._resolve_core_core(dst_ia, sibra=sibra)
            else:
                ret = self._resolve_core_not_core(dst_ia, sibra=sibra)
        elif dst_is_core:
            ret = self._resolve_not_core_core(dst_ia, sibra=sibra)
        elif sibra:
            ret = self._resolve_not_core_not_core_sibra(dst_ia)
        else:
            ret = self._resolve_not_core_not_core_scion(dst_ia)
        if not sibra:
            return ret
        # FIXME(kormat): Strip off PCBs, and just return sibra reservation
        # blocks
        return self._sibra_strip_pcbs(self._strip_nones(ret))

    def _resolve_core_core(self, dst_ia, sibra=False):
        """Resolve path from core to core."""
        res = set()
        for cseg in self.core_segments(last_ia=self.addr.isd_as, sibra=sibra,
                                       **dst_ia.params()):
            res.add((None, cseg, None))
        if sibra:
            return res
        return tuples_to_full_paths(res)

    def _resolve_core_not_core(self, dst_ia, sibra=False):
        """Resolve path from core to non-core."""
        res = set()
        # First check whether there is a direct path.
        for dseg in self.down_segments(
                first_ia=self.addr.isd_as, last_ia=dst_ia, sibra=sibra):
            res.add((None, None, dseg))
        # Check core-down combination.
        for dseg in self.down_segments(last_ia=dst_ia, sibra=sibra):
            dseg_ia = dseg.first_ia()
            if self.addr.isd_as == dseg_ia:
                # Down segments starting at the local AS were already added as
                # direct paths above; no core segment is needed for them.
                continue
            for cseg in self.core_segments(
                    first_ia=dseg_ia, last_ia=self.addr.isd_as, sibra=sibra):
                res.add((None, cseg, dseg))
        if sibra:
            return res
        return tuples_to_full_paths(res)

    def _resolve_not_core_core(self, dst_ia, sibra=False):
        """Resolve path from non-core to core."""
        res = set()
        params = dst_ia.params()
        params["sibra"] = sibra
        if dst_ia[0] == self.addr.isd_as[0]:
            # Dst in local ISD. First check whether DST is a (super)-parent.
            for useg in self.up_segments(**params):
                res.add((useg, None, None))
        # Check whether dst is known core AS.
        for cseg in self.core_segments(**params):
            # Check do we have an up-seg that is connected to core_seg.
            for useg in self.up_segments(first_ia=cseg.last_ia(), sibra=sibra):
                res.add((useg, cseg, None))
        if sibra:
            return res
        return tuples_to_full_paths(res)

    def _resolve_not_core_not_core_scion(self, dst_ia):
        """Resolve SCION path from non-core to non-core."""
        up_segs = self.up_segments()
        down_segs = self.down_segments(last_ia=dst_ia)
        core_segs = self._calc_core_segs(dst_ia[0], up_segs, down_segs)
        full_paths = build_shortcut_paths(
            up_segs, down_segs, self.rev_cache)
        tuples = []
        for up_seg in up_segs:
            for down_seg in down_segs:
                tuples.append((up_seg, None, down_seg))
                for core_seg in core_segs:
                    tuples.append((up_seg, core_seg, down_seg))
        full_paths.extend(tuples_to_full_paths(tuples))
        return full_paths

    def _resolve_not_core_not_core_sibra(self, dst_ia):
        """Resolve SIBRA path from non-core to non-core."""
        res = set()
        up_segs = set(self.up_segments(sibra=True))
        down_segs = set(self.down_segments(last_ia=dst_ia, sibra=True))
        for up_seg, down_seg in product(up_segs, down_segs):
            src_core_ia = up_seg.first_ia()
            dst_core_ia = down_seg.first_ia()
            if src_core_ia == dst_core_ia:
                res.add((up_seg, down_seg))
                continue
            for core_seg in self.core_segments(first_ia=dst_core_ia,
                                               last_ia=src_core_ia, sibra=True):
                res.add((up_seg, core_seg, down_seg))
        return res

    def _strip_nones(self, set_):
        """Strip None entries from a set of tuples"""
        res = []
        for tup in set_:
            res.append(tuple(filter(None, tup)))
        return res

    def _sibra_strip_pcbs(self, paths):
        """Convert each tuple of PCBs into a list of _sibra_strip_pcb results."""
        ret = []
        for pcbs in paths:
            resvs = []
            for pcb in pcbs:
                resvs.append(self._sibra_strip_pcb(pcb))
            ret.append(resvs)
        return ret

    def _sibra_strip_pcb(self, pcb):
        """
        Build a steady reservation block from a SIBRA PCB.

        :returns: tuple of (reservation id, ResvBlockSteady, list of
            (ISD-AS, interface id) pairs).
        """
        assert pcb.is_sibra()
        pcb_ext = pcb.sibra_ext
        resv_info = pcb_ext.info
        resv = ResvBlockSteady.from_values(resv_info, pcb.get_n_hops())
        asms = pcb.iter_asms()
        if pcb_ext.p.up:
            asms = reversed(list(asms))
        iflist = []
        for sof, asm in zip(pcb_ext.iter_sofs(), asms):
            resv.sofs.append(sof)
            iflist.extend(self._sibra_add_ifs(
                asm.isd_as(), sof, resv_info.fwd_dir))
        assert resv.num_hops == len(resv.sofs)
        return pcb_ext.p.id, resv, iflist

    def _sibra_add_ifs(self, isd_as, sof, fwd):
        """
        Return the non-zero interfaces of `sof` as (isd_as, ifid) pairs,
        ingress first when `fwd` is true, egress first otherwise.
        """
        def _add(ifid):
            if ifid:
                ret.append((isd_as, ifid))
        ret = []
        if fwd:
            _add(sof.ingress)
            _add(sof.egress)
        else:
            _add(sof.egress)
            _add(sof.ingress)
        return ret

    def _wait_for_events(self, events, deadline):
        """
        Wait on a set of events, but only until the specified deadline. Returns
        the number of events that happened while waiting.
        """
        count = 0
        for e in events:
            if e.wait(max(0, deadline - SCIONTime.get_time())):
                count += 1
        return count

    def _fetch_segments(self, req):
        """
        Called to fetch the requested path.

        `req` is sent to the first path service address returned by
        dns_query_topo; a lookup failure is logged and nothing is sent.
        """
        try:
            addr, port = self.dns_query_topo(ServiceType.PS)[0]
        except SCIONServiceLookupError:
            log_exception("Error querying path service:")
            return
        req_id = mk_ctrl_req_id()
        logging.debug("Sending path request (%s) to [%s]:%s [id: %016x]",
                      req.short_desc(), addr, port, req_id)
        meta = self._build_meta(host=addr, port=port)
        self.send_meta(CtrlPayload(PathMgmt(req), req_id=req_id), meta)

    def _calc_core_segs(self, dst_isd, up_segs, down_segs):
        """
        Calculate all known core segments joining the provided up and down
        segments.

        The first AS of every up segment (in the local ISD) is paired with the
        first AS of every down segment (in `dst_isd`); the core segments for
        each pair are looked up via _find_core_segs.

        :returns: list of core segments.
        """
        # Only the AS part is kept; the ISDs are supplied separately below.
        src_core_ases = {seg.first_ia()[1] for seg in up_segs}
        dst_core_ases = {seg.first_ia()[1] for seg in down_segs}
        # Generate all possible AS pairs
        as_pairs = list(product(src_core_ases, dst_core_ases))
        return self._find_core_segs(self.addr.isd_as[0], dst_isd, as_pairs)

    def _find_core_segs(self, src_isd, dst_isd, as_pairs):
        """
        Given a set of AS pairs across 2 ISDs, return the core segments
        connecting those pairs
        """
        core_segs = []
        for src_core_as, dst_core_as in as_pairs:
            src_ia = ISD_AS.from_values(src_isd, src_core_as)
            dst_ia = ISD_AS.from_values(dst_isd, dst_core_as)
            if src_ia == dst_ia:
                continue
            seg = self.core_segments(first_ia=dst_ia, last_ia=src_ia)
            if seg:
                core_segs.extend(seg)
        return core_segs

    def run(self):
        """
        Run an instance of the SCION daemon.

        Starts the TRC/certificate request checker in a daemon thread, then
        runs the base class loop.
        """
        threading.Thread(
            target=thread_safety_net, args=(self._check_trc_cert_reqs,),
            name="Elem.check_trc_cert_reqs", daemon=True).start()
        super().run()
Ejemplo n.º 24
0
class PathServer(SCIONElement, metaclass=ABCMeta):
    """
    The SCION Path Server.

    Abstract base: keeps down- and core-segment databases, shares newly
    learned segments with other instances through a ZooKeeper shared cache,
    and drops segments whose interfaces have been revoked. Subclasses provide
    path resolution and the per-segment-type record handlers.
    """
    SERVICE_TYPE = PATH_SERVICE
    MAX_SEG_NO = 5  # TODO: replace by config variable.
    # ZK path for incoming PATHs
    ZK_PATH_CACHE_PATH = "path_cache"
    # Number of tokens the PS checks when receiving a revocation.
    N_TOKENS_CHECK = 20
    # Max number of segments per propagation packet
    PROP_LIMIT = 5
    # Max number of segments per ZK cache entry
    ZK_SHARE_LIMIT = 10

    def __init__(self, server_id, conf_dir):
        """
        Set up segment databases, request/revocation bookkeeping, the control
        payload dispatch table and the ZooKeeper party and shared cache.

        :param str server_id: server identifier.
        :param str conf_dir: configuration directory.
        """
        super().__init__(server_id, conf_dir)
        self.down_segments = PathSegmentDB(max_res_no=self.MAX_SEG_NO)
        self.core_segments = PathSegmentDB(max_res_no=self.MAX_SEG_NO)
        # Pending requests, keyed by (dst ISD-AS, sibra) -- see
        # _handle_pending_requests.
        self.pending_req = defaultdict(list)
        # Used when l/cPS doesn't have up/dw-path; keyed by destination ISD.
        self.waiting_targets = defaultdict(list)
        # Already-seen revocations, keyed by hash(RevocationInfo).
        self.revocations = ExpiringDict(1000, 300)
        # Revocation token -> set of segment IDs (hops hashes) using it.
        self.iftoken2seg = defaultdict(set)
        self.CTRL_PLD_CLASS_MAP = {
            PayloadClass.PATH: {
                PMT.REQUEST: self.path_resolution,
                PMT.REPLY: self.handle_path_segment_record,
                PMT.REG: self.handle_path_segment_record,
                PMT.REVOCATION: self._handle_revocation,
                PMT.SYNC: self.handle_path_segment_record,
            },
        }
        # (type, pcb) tuples waiting to be shared via ZK.
        self._segs_to_zk = deque()
        # Add more IPs here if we support dual-stack
        name_addrs = "\0".join(
            [self.id, str(SCION_UDP_PORT),
             str(self.addr.host)])
        self.zk = Zookeeper(self.topology.isd_as, PATH_SERVICE, name_addrs,
                            self.topology.zookeepers)
        self.zk.retry("Joining party", self.zk.party_setup)
        self.path_cache = ZkSharedCache(self.zk, self.ZK_PATH_CACHE_PATH,
                                        self._cached_entries_handler)

    def worker(self):
        """
        Worker thread that takes care of reading shared paths from ZK, and
        handling master election for core servers.

        Runs one cycle per second while ``self.run_flag`` is set. A lost ZK
        connection is logged and the cycle continues; the master hook and
        propagation/sync run on every cycle.
        """
        worker_cycle = 1.0
        start = SCIONTime.get_time()
        while self.run_flag.is_set():
            sleep_interval(start, worker_cycle, "cPS.worker cycle",
                           self._quiet_startup())
            start = SCIONTime.get_time()
            try:
                self.zk.wait_connected()
                self.path_cache.process()
                # Try to become a master (zero timeouts: do not wait).
                is_master = self.zk.get_lock(lock_timeout=0, conn_timeout=0)
                if is_master:
                    # Only the master expires old shared-cache entries.
                    self.path_cache.expire(self.config.propagation_time * 10)
            except ZkNoConnection:
                logging.warning('worker(): ZkNoConnection')
            self._update_master()
            self._propagate_and_sync()

    def _cached_entries_handler(self, raw_entries):
        """
        Handles cached through ZK entries, passed as a list.

        Each raw entry is a packed PathSegmentRecords; every contained PCB is
        dispatched with ``from_zk=True``.
        """
        count = 0
        for raw in raw_entries:
            recs = PathSegmentRecords.from_raw(raw)
            for type_, pcb in recs.iter_pcbs():
                count += 1
                self._dispatch_segment_record(type_, pcb, from_zk=True)
        if count:
            logging.debug("Processed %s PCBs from ZK", count)

    def _update_master(self):
        """Per-cycle hook called by ``worker``; a no-op in this base class."""
        pass

    def _add_if_mappings(self, pcb):
        """
        Add if revocation token to segment ID mappings.

        For every AS entry, the egress revocation token and the ingress
        revocation token of each hop entry (``pcbms``) are mapped to the
        segment's hops hash.
        """
        segment_id = pcb.get_hops_hash()
        for asm in pcb.p.asms:
            # Covers pcbms[0] as well, so no separate insert is needed for it.
            for pm in asm.pcbms:
                self.iftoken2seg[pm.igRevToken].add(segment_id)
            self.iftoken2seg[asm.egRevToken].add(segment_id)

    @abstractmethod
    def _handle_up_segment_record(self, pcb, **kwargs):
        """Handle an up-segment record; implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def _handle_down_segment_record(self, pcb, **kwargs):
        """Handle a down-segment record; implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def _handle_core_segment_record(self, pcb, **kwargs):
        """Handle a core-segment record; implemented by subclasses."""
        raise NotImplementedError

    def _add_segment(self, pcb, seg_db, name, reverse=False):
        """
        Insert or update ``pcb`` in ``seg_db``.

        :returns: True only if the segment was newly added (its revocation
            token mappings are recorded in that case); False otherwise.
        """
        res = seg_db.update(pcb, reverse=reverse)
        if res == DBResult.ENTRY_ADDED:
            self._add_if_mappings(pcb)
            logging.info("%s-Segment registered: %s", name, pcb.short_desc())
            return True
        elif res == DBResult.ENTRY_UPDATED:
            logging.debug("%s-Segment updated: %s", name, pcb.short_desc())
        return False

    def _handle_revocation(self, pkt):
        """
        Handles a revocation of a segment, interface or hop.

        Duplicate revocations (same hash) are dropped; new ones are
        remembered and used to purge affected segments.

        :param pkt: The packet containing the revocation info.
        :type pkt: PathMgmtPacket
        """
        rev_info = pkt.get_payload()
        assert isinstance(rev_info, RevocationInfo)
        rev_hash = hash(rev_info)
        if rev_hash in self.revocations:
            logging.debug("Already received revocation. Dropping...")
            return
        self.revocations[rev_hash] = rev_info
        logging.debug("Received revocation from %s:\n%s", pkt.addrs.src,
                      rev_info)
        # Remove segments that contain the revoked interface.
        self._remove_revoked_segments(rev_info)

    def _remove_revoked_segments(self, rev_info):
        """
        Remove segments that contain a revoked interface.

        The revocation token is hashed ``N_TOKENS_CHECK`` times in a row with
        SHA256, and segments mapped to any of those hashes are deleted, in
        case previous revocations were missed by the PS.

        :param rev_info: The revocation info
        :type rev_info: RevocationInfo
        """
        rev_token = rev_info.rev_token
        for _ in range(self.N_TOKENS_CHECK):
            rev_token = SHA256.new(rev_token).digest()
            # pop() removes the mapping without first creating an empty
            # defaultdict entry for unknown tokens.
            for sid in self.iftoken2seg.pop(rev_token, ()):
                self.down_segments.delete(sid)
                self.core_segments.delete(sid)

    def _send_to_next_hop(self, pkt, if_id):
        """
        Sends the packet to the next hop of the given if_id.

        Unknown interface IDs are logged and the packet is dropped.

        :param if_id: The interface ID of the corresponding interface.
        :type if_id: int.
        """
        if if_id not in self.ifid2er:
            logging.error("Unknown Interface ID: %d", if_id)
            return
        next_hop = self.ifid2er[if_id].addr
        self.send(pkt, next_hop)

    def _send_path_segments(self, pkt, up=None, core=None, down=None):
        """
        Sends path-segments to requester (depending on Path Server's location).

        Builds a PathRecordsReply on a reversed copy of the request packet.
        Nothing is sent if all three sets are empty or no first hop is found.
        """
        up = up or set()
        core = core or set()
        down = down or set()
        all_segs = up | core | down
        if not all_segs:
            logging.warning("No segments to send")
            return
        req = pkt.get_payload()
        rep_pkt = pkt.reversed_copy()
        rep_pkt.set_payload(
            PathRecordsReply.from_values({
                PST.UP: up,
                PST.CORE: core,
                PST.DOWN: down,
            }))
        rep_pkt.addrs.src.host = self.addr.host
        next_hop, port = self.get_first_hop(rep_pkt)
        if next_hop is None:
            logging.error("Next hop is None for Interface %s",
                          rep_pkt.path.get_fwd_if())
            return
        logging.info(
            "Sending PATH_REPLY with %d segment(s) to:%s "
            "port:%s in response to: %s",
            len(all_segs),
            rep_pkt.addrs.dst,
            rep_pkt.l4_hdr.dst_port,
            req.short_desc(),
        )
        self.send(rep_pkt, next_hop, port)

    def _handle_pending_requests(self, dst_ia, sibra):
        """
        Retry the requests queued under (dst_ia, sibra); answered ones are
        removed, and the key is dropped once its queue is empty.
        """
        key = dst_ia, sibra
        # .get() avoids creating an empty defaultdict entry for keys that
        # nobody is waiting on.
        pending = self.pending_req.get(key)
        if pending is None:
            return
        # Serve pending requests.
        to_remove = [pkt for pkt in pending
                     if self.path_resolution(pkt, new_request=False)]
        # Clean state.
        for pkt in to_remove:
            pending.remove(pkt)
        if not pending:
            del self.pending_req[key]

    def handle_path_segment_record(self, pkt):
        """
        Dispatch every PCB in the packet's segment records, then retry the
        pending requests for each (dst_ia, sibra) key the handlers report.
        """
        seg_recs = pkt.get_payload()
        params = self._dispatch_params(pkt)
        added = set()
        for type_, pcb in seg_recs.iter_pcbs():
            added.update(self._dispatch_segment_record(type_, pcb, **params))
        # Handling pending requests, basing on added segments.
        for dst_ia, sibra in added:
            self._handle_pending_requests(dst_ia, sibra)

    def _dispatch_segment_record(self, type_, seg, **kwargs):
        """Route ``seg`` to the handler for its segment type and return the
        handler's result."""
        handle_map = {
            PST.UP: self._handle_up_segment_record,
            PST.CORE: self._handle_core_segment_record,
            PST.DOWN: self._handle_down_segment_record,
        }
        return handle_map[type_](seg, **kwargs)

    def _dispatch_params(self, pkt):
        """Extra keyword arguments for the segment handlers; none by default."""
        return {}

    def _propagate_and_sync(self):
        """Per-cycle sharing step; the base class only shares via ZK."""
        self._share_via_zk()

    def _gen_prop_recs(self, queue, limit=PROP_LIMIT):
        """
        Drain ``queue`` of (type, pcb) tuples, yielding dicts that map each
        type to a list of PCB copies, with at most ``limit`` PCBs per dict.
        """
        count = 0
        pcbs = defaultdict(list)
        while queue:
            count += 1
            type_, pcb = queue.popleft()
            pcbs[type_].append(pcb.copy())
            if count >= limit:
                yield pcbs
                count = 0
                pcbs = defaultdict(list)
        if pcbs:
            yield pcbs

    @abstractmethod
    def path_resolution(self, path_request, new_request=True):
        """
        Handles all types of path request.

        :param path_request: the packet carrying the path request.
        :param bool new_request: False when a request queued in
            ``pending_req`` is being retried (see _handle_pending_requests).
        :returns: truthy when the request has been answered; the pending
            queue drops such requests.
        """
        raise NotImplementedError

    def _handle_waiting_targets(self, pcb, reverse=False):
        """
        Handle any queries that are waiting for a path to any core AS in an ISD.

        The target is ``pcb.first_ia()`` (``pcb.last_ia()`` if ``reverse``);
        non-core targets are logged and ignored.
        """
        dst_ia = pcb.last_ia() if reverse else pcb.first_ia()
        if not self.is_core_as(dst_ia):
            logging.warning("Invalid waiting target, not a core AS: %s",
                            dst_ia)
            return
        self._send_waiting_queries(dst_ia[0], pcb)

    def _send_waiting_queries(self, dst_isd, pcb):
        """
        Send every request waiting on ``dst_isd`` along the reversed path of
        ``pcb`` to the PS at ``pcb.first_ia()``, emptying the waiting list.
        """
        targets = self.waiting_targets[dst_isd]
        if not targets:
            return
        path = pcb.get_path(reverse_direction=True)
        src_ia = pcb.first_ia()
        while targets:
            seg_req = targets.pop(0)
            req_pkt = self._build_packet(SVCType.PS,
                                         dst_ia=src_ia,
                                         path=path,
                                         payload=seg_req)
            self._send_to_next_hop(req_pkt, path.get_fwd_if())
            logging.info("Waiting request (%s) sent via %s",
                         seg_req.short_desc(), pcb.short_desc())

    def _share_via_zk(self):
        """Write queued segments to the ZK shared cache in batches of at most
        ZK_SHARE_LIMIT."""
        if not self._segs_to_zk:
            return
        logging.info("Sharing %d segment(s) via ZK", len(self._segs_to_zk))
        for pcb_dict in self._gen_prop_recs(self._segs_to_zk,
                                            limit=self.ZK_SHARE_LIMIT):
            seg_recs = PathSegmentRecords.from_values(pcb_dict)
            self._zk_write(seg_recs.pack())

    def _zk_write(self, data):
        """
        Store ``data`` in the shared path cache under
        "<sha256 hex>-<current time>"; a lost ZK connection is only logged.
        """
        hash_ = SHA256.new(data).hexdigest()
        try:
            self.path_cache.store("%s-%s" % (hash_, SCIONTime.get_time()),
                                  data)
        except ZkNoConnection:
            logging.warning("Unable to store segment(s) in shared path: "
                            "no connection to ZK")

    def run(self):
        """
        Run an instance of the Path Server.

        Starts ``worker`` in a daemon thread, then enters the parent run loop.
        """
        threading.Thread(target=thread_safety_net,
                         args=(self.worker, ),
                         name="PS.worker",
                         daemon=True).start()
        super().run()
Ejemplo n.º 25
0
 def test_not_present(self):
     """delete() of an ID the DB does not hold must report DBResult.NONE."""
     db = PathSegmentDB()
     # Stub the internal `_db` callable so that calling it returns False.
     db._db = create_mock()
     db._db.return_value = False
     outcome = db.delete("data")
     ntools.eq_(outcome, DBResult.NONE)
Ejemplo n.º 26
0
class PathServer(SCIONElement, metaclass=ABCMeta):
    """
    The SCION Path Server.
    """
    SERVICE_TYPE = PATH_SERVICE
    MAX_SEG_NO = 5  # TODO: replace by config variable.
    # ZK path for incoming PATHs
    ZK_PATH_CACHE_PATH = "path_cache"
    # ZK path for incoming REVs
    ZK_REV_CACHE_PATH = "rev_cache"
    # Max number of segments per propagation packet
    PROP_LIMIT = 5
    # Max number of segments per ZK cache entry
    ZK_SHARE_LIMIT = 10
    # Time to store revocations in zookeeper
    ZK_REV_OBJ_MAX_AGE = HASHTREE_EPOCH_TIME
    # TTL of segments in the queue for ZK (in seconds)
    SEGS_TO_ZK_TTL = 10 * 60

    def __init__(self,
                 server_id,
                 conf_dir,
                 spki_cache_dir=GEN_CACHE_PATH,
                 prom_export=None):
        """
        Set up segment stores, request/revocation bookkeeping, payload
        dispatch tables and the ZooKeeper party plus shared path/revocation
        caches.

        :param str server_id: server identifier.
        :param str conf_dir: configuration directory.
        :param spki_cache_dir: cache directory passed through to the parent
            class (defaults to GEN_CACHE_PATH).
        :param str prom_export: prometheus export address.
        """
        super().__init__(server_id,
                         conf_dir,
                         spki_cache_dir=spki_cache_dir,
                         prom_export=prom_export)
        self.config = self._load_as_conf()
        # Metric labels for each segment DB: the element's labels plus a
        # "type" tag, or None when the element has no labels.
        down_labels = {
            **self._labels, "type": "down"
        } if self._labels else None
        core_labels = {
            **self._labels, "type": "core"
        } if self._labels else None
        self.down_segments = PathSegmentDB(max_res_no=self.MAX_SEG_NO,
                                           labels=down_labels)
        self.core_segments = PathSegmentDB(max_res_no=self.MAX_SEG_NO,
                                           labels=core_labels)
        # Dict of pending requests.
        # Each value is an ExpiringDict mapping a request key to
        # (req, req_id, meta, logger); see _handle_pending_requests.
        self.pending_req = defaultdict(
            lambda: ExpiringDict(1000, PATH_REQ_TOUT))
        # Held while pending_req is iterated/modified.
        self.pen_req_lock = threading.Lock()
        self._request_logger = None
        # Used when l/cPS doesn't have up/dw-path.
        # Keyed by destination ISD; values are lists of (request, logger).
        self.waiting_targets = defaultdict(list)
        self.revocations = RevCache(labels=self._labels)
        # A mapping from (hash tree root of AS, IFID) to segments
        self.htroot_if2seg = ExpiringDict(1000,
                                          self.config.revocation_tree_ttl)
        # Held while htroot_if2seg is read or modified.
        self.htroot_if2seglock = Lock()
        self.CTRL_PLD_CLASS_MAP = {
            PayloadClass.PATH: {
                PMT.IFSTATE_INFOS: self.handle_ifstate_infos,
                PMT.REQUEST: self.path_resolution,
                PMT.REPLY: self.handle_path_reply,
                PMT.REG: self.handle_seg_recs,
                PMT.REVOCATION: self._handle_revocation,
                PMT.SYNC: self.handle_seg_recs,
            },
            PayloadClass.CERT: {
                CertMgmtType.CERT_CHAIN_REQ: self.process_cert_chain_request,
                CertMgmtType.CERT_CHAIN_REPLY: self.process_cert_chain_reply,
                CertMgmtType.TRC_REPLY: self.process_trc_reply,
                CertMgmtType.TRC_REQ: self.process_trc_request,
            },
        }
        self.SCMP_PLD_CLASS_MAP = {
            SCMPClass.PATH: {
                SCMPPathClass.REVOKED_IF: self._handle_scmp_revocation,
            },
        }
        # Queues drained by _share_via_zk (values: (type, pcb)) and
        # _share_revs_via_zk (RevocationInfo -> packed bytes).
        self._segs_to_zk = ExpiringDict(1000, self.SEGS_TO_ZK_TTL)
        self._revs_to_zk = ExpiringDict(1000, HASHTREE_EPOCH_TIME)
        self._zkid = ZkID.from_values(self.addr.isd_as, self.id,
                                      [(self.addr.host, self._port)])
        self.zk = Zookeeper(self.topology.isd_as, PATH_SERVICE,
                            self._zkid.copy().pack(), self.topology.zookeepers)
        self.zk.retry("Joining party", self.zk.party_setup)
        self.path_cache = ZkSharedCache(self.zk, self.ZK_PATH_CACHE_PATH,
                                        self._handle_paths_from_zk)
        self.rev_cache = ZkSharedCache(self.zk, self.ZK_REV_CACHE_PATH,
                                       self._rev_entries_handler)
        self._init_request_logger()

    def worker(self):
        """
        Background loop that runs one pass per second while the element is
        running.

        Each pass processes the shared path and revocation caches from ZK and
        tries to take the master lock. As master it expires stale cache
        entries. It then runs the master hook, propagation/ZK sharing, the
        pending-request scan and the metrics update.
        """
        period = 1.0
        cycle_start = SCIONTime.get_time()
        while self.run_flag.is_set():
            sleep_interval(cycle_start, period, "cPS.worker cycle",
                           self._quiet_startup())
            cycle_start = SCIONTime.get_time()
            try:
                self.zk.wait_connected()
                self.path_cache.process()
                self.rev_cache.process()
                # Zero timeouts: attempt (or keep) the master lock without
                # waiting.
                lock_state = self.zk.get_lock(lock_timeout=0, conn_timeout=0)
                if lock_state:
                    # Truthy: we hold the lock. ZK_LOCK_SUCCESS: newly gained.
                    if lock_state == ZK_LOCK_SUCCESS:
                        logging.info("Became master")
                    self.path_cache.expire(self.config.propagation_time * 10)
                    self.rev_cache.expire(self.ZK_REV_OBJ_MAX_AGE)
            except ZkNoConnection:
                logging.warning('worker(): ZkNoConnection')
            self._update_master()
            self._propagate_and_sync()
            self._handle_pending_requests()
            self._update_metrics()

    def _update_master(self):
        """Hook run once per worker cycle; intentionally does nothing here."""

    def _rev_entries_handler(self, raw_entries):
        """
        Process revocations read from the ZK shared revocation cache.

        Each raw entry is decoded into a RevocationInfo. Entries that fail
        validation are logged and skipped; the rest are used to purge revoked
        segments.
        """
        for entry in raw_entries:
            rev = RevocationInfo.from_raw(entry)
            try:
                rev.validate()
            except SCIONBaseError as err:
                logging.warning("Failed to validate RevInfo from zk: %s\n%s",
                                err, rev.short_desc())
            else:
                self._remove_revoked_segments(rev)

    def _add_rev_mappings(self, pcb):
        """
        Add if revocation token to segment ID mappings.

        For each AS entry, both the egress and the ingress interface of its
        first hop field are keyed by (hash tree root, interface ID) and mapped
        to this segment's hops hash. This is done under
        ``htroot_if2seglock``.
        """
        seg_id = pcb.get_hops_hash()
        with self.htroot_if2seglock:
            for asm in pcb.iter_asms():
                hop = asm.pcbm(0).hof()
                tree_root = asm.p.hashTreeRoot
                for ifid in (hop.egress_if, hop.ingress_if):
                    self.htroot_if2seg.setdefault(
                        (tree_root, ifid), set()).add(seg_id)

    @abstractmethod
    def _handle_up_segment_record(self, pcb, **kwargs):
        """Handle a received up-segment record; subclasses must implement."""
        raise NotImplementedError()

    @abstractmethod
    def _handle_down_segment_record(self, pcb, **kwargs):
        """Handle a received down-segment record; subclasses must implement."""
        raise NotImplementedError()

    @abstractmethod
    def _handle_core_segment_record(self, pcb, **kwargs):
        """Handle a received core-segment record; subclasses must implement."""
        raise NotImplementedError()

    def _add_segment(self, pcb, seg_db, name, reverse=False):
        """
        Insert or refresh ``pcb`` in ``seg_db``.

        On both insertion and update, the segment's revocation mappings are
        (re)recorded.

        :returns: True only when the segment was newly added.
        """
        outcome = seg_db.update(pcb, reverse=reverse)
        if outcome == DBResult.ENTRY_ADDED:
            self._add_rev_mappings(pcb)
            logging.info("%s-Segment registered: %s", name, pcb.short_id())
            return True
        if outcome == DBResult.ENTRY_UPDATED:
            self._add_rev_mappings(pcb)
            logging.debug("%s-Segment updated: %s", name, pcb.short_id())
        return False

    def handle_ifstate_infos(self, cpld, meta):
        """
        Handles IFStateInfos.

        For every inactive interface whose state info carries a revocation,
        the revocation is validated and passed to ``_handle_revocation``.
        Revocations that fail validation are logged and skipped.

        :param cpld: CtrlPayload wrapping a PathMgmt whose union is an
            IFStatePayload.
        :param meta: metadata of the sender; used for logging and forwarded
            to ``_handle_revocation``.
        """
        pmgt = cpld.union
        infos = pmgt.union
        assert isinstance(infos, IFStatePayload), type(infos)
        for info in infos.iter_infos():
            if info.p.active or not info.p.revInfo:
                continue
            rev_info = info.rev_info()
            try:
                rev_info.validate()
            except SCIONBaseError as e:
                logging.warning(
                    "Failed to validate IFStateInfo RevInfo from %s: %s\n%s",
                    meta, e, rev_info.short_desc())
                continue
            # Forward the object that was just validated rather than
            # re-parsing it with a second info.rev_info() call.
            self._handle_revocation(CtrlPayload(PathMgmt(rev_info)), meta)

    def _handle_scmp_revocation(self, pld, meta):
        """
        Handle a REVOKED_IF SCMP message.

        The embedded revocation is decoded and validated, then processed like
        a control-plane revocation. Invalid revocations are only logged.
        """
        rev = RevocationInfo.from_raw(pld.info.rev_info)
        try:
            rev.validate()
        except SCIONBaseError as err:
            logging.warning("Failed to validate SCMP RevInfo from %s: %s\n%s",
                            meta, err, rev.short_desc())
        else:
            self._handle_revocation(CtrlPayload(PathMgmt(rev)), meta)

    def _handle_revocation(self, cpld, meta):
        """
        Handles a revocation of a segment, interface or hop.

        The revocation is validated and de-duplicated against
        ``self.revocations``. Revocations from another ISD are dropped. An
        accepted revocation is cached, queued for sharing via ZK, used to
        purge affected segments, and forwarded to other path servers.

        :param cpld: CtrlPayload wrapping a PathMgmt whose union is the
            RevocationInfo.
        :param meta: metadata of the sender.
        """
        pmgt = cpld.union
        rev_info = pmgt.union
        assert isinstance(rev_info, RevocationInfo), type(rev_info)
        # Validate before checking for presence in self.revocations, as that
        # will trigger an assert failure if the rev_info is invalid.
        try:
            rev_info.validate()
        except SCIONBaseError as e:
            # Validation already done in the IFStateInfo and SCMP paths, so a
            # failure here means it's from a CtrlPld.
            logging.warning(
                "Failed to validate CtrlPld RevInfo from %s: %s\n%s", meta, e,
                rev_info.short_desc())
            return
        if rev_info in self.revocations:
            return
        logging.debug("Received revocation from %s: %s", meta,
                      rev_info.short_desc())
        if meta.ia[0] != self.addr.isd_as[0]:
            logging.info(
                "Dropping revocation received from a different ISD. "
                "Src: %s RevInfo: %s", meta, rev_info.short_desc())
            return
        self.revocations.add(rev_info)
        # Queue for ZK sharing; have to pack copy.
        self._revs_to_zk[rev_info] = rev_info.copy().pack()
        # Remove segments that contain the revoked interface.
        self._remove_revoked_segments(rev_info)
        # Forward revocation to other path servers.
        self._forward_revocation(rev_info, meta)

    def _remove_revoked_segments(self, rev_info):
        """
        Try the previous and next hashes as possible astokens,
        and delete any segment that matches.

        Revocations whose epoch does not verify as
        ``ConnectedHashTree.EPOCH_OK`` are ignored. For each candidate hash,
        the segments mapped to (hash, revoked IFID) are deleted from the down
        and core DBs. On non-core ASes they are also deleted from the up DB.

        :param rev_info: The revocation info
        :type rev_info: RevocationInfo
        """
        if ConnectedHashTree.verify_epoch(
                rev_info.p.epoch) != ConnectedHashTree.EPOCH_OK:
            return
        (hash01, hash12) = ConnectedHashTree.get_possible_hashes(rev_info)
        if_id = rev_info.p.ifID
        # up_segments is only consulted when this AS is not a core AS; the
        # check does not depend on the segment, so evaluate it once.
        check_up = not self.topology.is_core_as
        down_segs_removed = 0
        core_segs_removed = 0
        up_segs_removed = 0
        with self.htroot_if2seglock:
            for h in (hash01, hash12):
                for sid in self.htroot_if2seg.pop((h, if_id), []):
                    if (self.down_segments.delete(sid) ==
                            DBResult.ENTRY_DELETED):
                        down_segs_removed += 1
                    if (self.core_segments.delete(sid) ==
                            DBResult.ENTRY_DELETED):
                        core_segs_removed += 1
                    if check_up and (self.up_segments.delete(sid) ==
                                     DBResult.ENTRY_DELETED):
                        up_segs_removed += 1
        # The counters are local, so logging can happen after releasing the
        # lock.
        logging.debug(
            "Removed segments revoked by [%s]: UP: %d DOWN: %d CORE: %d",
            rev_info.short_desc(), up_segs_removed, down_segs_removed,
            core_segs_removed)

    @abstractmethod
    def _forward_revocation(self, rev_info, meta):
        """
        Pass a revocation on to the other path servers that must learn about
        it. Subclasses decide which servers those are.

        :param rev_info: The RevInfo object.
        :param meta: The MessageMeta object.
        """
        raise NotImplementedError()

    def _send_path_segments(self,
                            req,
                            req_id,
                            meta,
                            logger,
                            up=None,
                            core=None,
                            down=None):
        """
        Sends path-segments to requester (depending on Path Server's location).

        The reply bundles the up/core/down segments with any known peer
        revocations affecting them. Nothing is sent when all three sets are
        empty.

        :param req: the original path request (copied into the reply).
        :param req_id: control request ID echoed back in the reply.
        :param meta: destination metadata of the requester.
        :param logger: per-request logger.
        """
        up = up or set()
        core = core or set()
        down = down or set()
        all_segs = up | core | down
        if not all_segs:
            # Lazy %-args, matching the logger.info call below.
            logger.warning("No segments to send for request: %s from: %s",
                           req.short_desc(), meta)
            return
        revs_to_add = self._peer_revs_for_segs(all_segs)
        recs = PathSegmentRecords.from_values(
            {
                PST.UP: up,
                PST.CORE: core,
                PST.DOWN: down
            }, revs_to_add)
        pld = PathSegmentReply.from_values(req.copy(), recs)
        self.send_meta(CtrlPayload(PathMgmt(pld), req_id=req_id), meta)
        logger.info("Sending PATH_REPLY with %d segment(s).", len(all_segs))

    def _peer_revs_for_segs(self, segs):
        """
        Returns a list of peer revocations for segments in 'segs'.

        Every peering hop entry (``iter_pcbms(1)``) of every AS entry is
        checked. Its ingress and egress interfaces are looked up in
        ``self.revocations``, and one copy of each cached revocation found is
        returned. All revoked peer interfaces of a segment are reported, not
        just the first one encountered.
        """
        # Keyed by (ISD-AS, IFID) so each revocation is copied only once, even
        # if several segments or peer entries reference the same interface.
        found = {}
        for seg in segs:
            for asm in seg.iter_asms():
                isd_as = asm.isd_as()
                for pcbm in asm.iter_pcbms(1):
                    hof = pcbm.hof()
                    for if_id in (hof.ingress_if, hof.egress_if):
                        key = (isd_as, if_id)
                        if key in found:
                            continue
                        rev_info = self.revocations.get(key)
                        if rev_info:
                            found[key] = rev_info.copy()
        return list(found.values())

    def _handle_pending_requests(self):
        """
        Retry every pending path request under ``pen_req_lock``.

        Requests that ``path_resolution`` now answers have their meta closed
        and are removed. Destination keys whose queue ends up empty are
        dropped.
        """
        with self.pen_req_lock:
            emptied = []
            for dst_key in self.pending_req:
                waiting = self.pending_req[dst_key]
                for req_key, (req, req_id, meta,
                              logger) in waiting.items():
                    retry_cpld = CtrlPayload(PathMgmt(req), req_id=req_id)
                    answered = self.path_resolution(retry_cpld,
                                                    meta,
                                                    new_request=False,
                                                    logger=logger)
                    if not answered:
                        continue
                    meta.close()
                    del waiting[req_key]
                if not waiting:
                    emptied.append(dst_key)
            for dst_key in emptied:
                del self.pending_req[dst_key]

    def _handle_paths_from_zk(self, raw_entries):
        """
        Handles cached paths through ZK, passed as a list.

        Each raw entry is a packed PathSegmentRecords and may hold several
        segments (``_share_via_zk`` batches up to ZK_SHARE_LIMIT per entry).
        Every segment is wrapped in a PathSegMeta with
        ``params={'from_zk': True}`` and ``continue_seg_processing`` as its
        callback, then handed to ``_process_path_seg``.
        """
        count = 0
        for raw in raw_entries:
            recs = PathSegmentRecords.from_raw(raw)
            for type_, pcb in recs.iter_pcbs():
                count += 1
                seg_meta = PathSegMeta(pcb,
                                       self.continue_seg_processing,
                                       type_=type_,
                                       params={'from_zk': True})
                self._process_path_seg(seg_meta)
        # Report the number of segments, not ZK entries.
        if count:
            logging.debug("Processed %s segments from ZK", count)

    def handle_path_reply(self, cpld, meta):
        """Unwrap a PATH_REPLY and process the segment records it carries."""
        reply = cpld.union.union
        assert isinstance(reply, PathSegmentReply), type(reply)
        self._handle_seg_recs(reply.recs(), cpld.req_id, meta)

    def handle_seg_recs(self, cpld, meta):
        """Process the segment records of a REG or SYNC path-management
        message."""
        self._handle_seg_recs(cpld.union.union, cpld.req_id, meta)

    def _handle_seg_recs(self, seg_recs, req_id, meta):
        """
        Handles paths received from the network.

        Peer-interface revocations shipped with the records are cached first.
        Each segment is then wrapped in a PathSegMeta and submitted for
        verification; processing resumes in ``continue_seg_processing``.
        """
        assert isinstance(seg_recs, PathSegmentRecords), type(seg_recs)
        params = self._dispatch_params(seg_recs, meta)
        for rev in seg_recs.iter_rev_infos():
            self.revocations.add(rev)
        for seg_type, seg in seg_recs.iter_pcbs():
            self._process_path_seg(
                PathSegMeta(seg, self.continue_seg_processing, meta,
                            seg_type, params),
                req_id)

    def continue_seg_processing(self, seg_meta):
        """
        Verification callback for a path segment received from the network or
        from ZK.

        The segment's extensions are handled, the segment is dispatched to its
        type-specific handler (and thereby into the path DB), and pending
        requests are re-checked.
        """
        seg = seg_meta.seg
        logging.debug("Successfully verified PCB %s" % seg.short_id())
        seg_type, extra = seg_meta.type, seg_meta.params
        self.handle_ext(seg)
        self._dispatch_segment_record(seg_type, seg, **extra)
        self._handle_pending_requests()

    def handle_ext(self, pcb):
        """
        Handle beacon extensions.

        Currently only routing-policy extensions of each AS entry are looked
        at.
        """
        for asm in pcb.iter_asms():
            routing_pol = asm.routing_pol_ext()
            if not routing_pol:
                continue
            self.handle_routing_pol_ext(routing_pol)

    def handle_routing_pol_ext(self, ext):
        """Placeholder for routing-policy extensions: only logs them."""
        # TODO(Sezer): Implement extension handling
        message = "Routing policy extension: %s" % ext
        logging.debug(message)

    def _dispatch_segment_record(self, type_, seg, **kwargs):
        """
        Route ``seg`` to the handler for its type (PST.UP/CORE/DOWN). Segments
        that contain a revoked interface are silently dropped.
        """
        if self._validate_segment(seg):
            handlers = {
                PST.UP: self._handle_up_segment_record,
                PST.CORE: self._handle_core_segment_record,
                PST.DOWN: self._handle_down_segment_record,
            }
            handlers[type_](seg, **kwargs)

    def _validate_segment(self, seg):
        """
        Check segment for revoked upstream/downstream interfaces.

        Only the first hop entry of each AS entry is inspected; peer entries
        are not.

        :param seg: The PathSegment object.
        :return: False if any inspected ingress/egress interface has a cached
            revocation, True otherwise.
        """
        for asm in seg.iter_asms():
            hop = asm.pcbm(0).hof()
            for if_id in (hop.ingress_if, hop.egress_if):
                revoked = self.revocations.get((asm.isd_as(), if_id))
                if not revoked:
                    continue
                logging.debug(
                    "Found revoked interface (%d, %s) in segment %s." %
                    (revoked.p.ifID, revoked.isd_as(), seg.short_desc()))
                return False
        return True

    def _dispatch_params(self, pld, meta):
        """Extra keyword arguments for the segment handlers; none by default."""
        return dict()

    def _propagate_and_sync(self):
        """Per-cycle sharing step: push queued segments, then revocations, to
        ZK."""
        for share in (self._share_via_zk, self._share_revs_via_zk):
            share()

    def _gen_prop_recs(self, container, limit=PROP_LIMIT):
        """
        Drain ``container`` (oldest entry first), yielding dicts that map each
        segment type to a list of PCB copies, with at most ``limit`` PCBs per
        dict.

        Values of ``container`` are (type, pcb) tuples. A KeyError from
        ``popitem`` is ignored and draining continues.
        """
        batch = defaultdict(list)
        batch_size = 0
        while container:
            try:
                entry = container.popitem(last=False)
            except KeyError:
                continue
            seg_type, seg = entry[1]
            batch_size += 1
            batch[seg_type].append(seg.copy())
            if batch_size >= limit:
                yield batch
                batch, batch_size = defaultdict(list), 0
        if batch:
            yield batch

    @abstractmethod
    def path_resolution(self,
                        path_request,
                        meta,
                        new_request=True,
                        logger=None):
        """
        Handles all types of path request; implemented by subclasses.

        ``_handle_pending_requests`` retries queued requests with
        ``new_request=False`` and treats a truthy return value as "answered".
        """
        raise NotImplementedError()

    def _handle_waiting_targets(self, pcb):
        """
        Handle any queries that are waiting for a path to any core AS in an ISD.

        The target is ``pcb.first_ia()``; if it is not a core AS the call only
        logs a warning.
        """
        dst_ia = pcb.first_ia()
        if self.is_core_as(dst_ia):
            self._send_waiting_queries(dst_ia[0], pcb)
            return
        logging.warning("Invalid waiting target, not a core AS: %s",
                        dst_ia)

    def _send_waiting_queries(self, dst_isd, pcb):
        targets = self.waiting_targets[dst_isd]
        if not targets:
            return
        path = pcb.get_path(reverse_direction=True)
        src_ia = pcb.first_ia()
        while targets:
            (seg_req, logger) = targets.pop(0)
            meta = self._build_meta(ia=src_ia,
                                    path=path,
                                    host=SVCType.PS_A,
                                    reuse=True)
            req_id = mk_ctrl_req_id()
            self.send_meta(CtrlPayload(PathMgmt(seg_req), req_id=req_id), meta)
            logger.info("Waiting request (%s) sent to %s via %s [id: %016x]",
                        seg_req.short_desc(), meta, pcb.short_desc(), req_id)

    def _share_via_zk(self):
        if not self._segs_to_zk:
            return
        logging.info("Sharing %d segment(s) via ZK", len(self._segs_to_zk))
        for pcb_dict in self._gen_prop_recs(self._segs_to_zk,
                                            limit=self.ZK_SHARE_LIMIT):
            seg_recs = PathSegmentRecords.from_values(pcb_dict)
            self._zk_write(seg_recs.pack())

    def _share_revs_via_zk(self):
        if not self._revs_to_zk:
            return
        logging.info("Sharing %d revocation(s) via ZK", len(self._revs_to_zk))
        while self._revs_to_zk:
            try:
                data = self._revs_to_zk.popitem(last=False)[1]
            except KeyError:
                continue
            self._zk_write_rev(data)

    def _zk_write(self, data):
        """
        Store packed segment records in the shared ZK path cache.

        The entry name is "<hex crypto hash of data>-<current SCION time>".
        If there is no ZK connection, a warning is logged and the data is not
        stored.
        """
        digest = crypto_hash(data).hex()
        try:
            entry = "%s-%s" % (digest, SCIONTime.get_time())
            self.path_cache.store(entry, data)
        except ZkNoConnection:
            logging.warning("Unable to store segment(s) in shared path: "
                            "no connection to ZK")

    def _zk_write_rev(self, data):
        """
        Store a packed revocation in the shared ZK revocation cache.

        The entry name is "<hex crypto hash of data>-<current SCION time>".
        If there is no ZK connection, a warning is logged and the data is not
        stored.
        """
        digest = crypto_hash(data).hex()
        try:
            entry = "%s-%s" % (digest, SCIONTime.get_time())
            self.rev_cache.store(entry, data)
        except ZkNoConnection:
            logging.warning("Unable to store revocation(s) in shared path: "
                            "no connection to ZK")

    def _init_request_logger(self):
        """
        Initialize the request logger.

        Fetches the "RequestLogger" logger into `self._request_logger`. It then
        passes add_formatter() a formatter whose format string appends the
        per-request context fields `id` and `from`. Those fields are the ones
        supplied by get_request_logger().
        """
        self._request_logger = logging.getLogger("RequestLogger")
        # Create a formatter that appends the request context to each record.
        formatter = Rfc3339Formatter(
            "%(asctime)s [%(levelname)s] (%(threadName)s) %(message)s "
            "{id=%(id)s, from=%(from)s}")
        add_formatter('RequestLogger', formatter)

    def get_request_logger(self, req_id, meta):
        """
        Return a LoggerAdapter over the request logger for one request.

        The adapter attaches `id` (req_id) and `from` (str(meta)) as extra
        context to every record. These match the fields named in the request
        logger's format string.
        """
        base = self._request_logger
        context = {"id": req_id, "from": str(meta)}
        return logging.LoggerAdapter(base, context)

    def _init_metrics(self):
        """
        Initialize this server's metrics after the superclass' metrics.

        The request counter is incremented by 0 and every gauge is set to 0.
        This presumably makes each labelled series exist from startup.
        """
        super()._init_metrics()
        labels = self._labels
        REQS_TOTAL.labels(**labels).inc(0)
        for gauge in (REQS_PENDING, SEGS_TO_ZK, REVS_TO_ZK, HT_ROOT_MAPPTINGS,
                      IS_MASTER):
            gauge.labels(**labels).set(0)

    def _update_metrics(self):
        """
        Updates all Gauge metrics. Subclass can update their own metrics but must
        call the superclass' implementation.
        """
        if not self._labels:
            return
        # Update pending requests metric.
        # XXX(shitz): This could become a performance problem should there ever be
        # a large amount of pending requests (>100'000).
        total_pending = 0
        with self.pen_req_lock:
            for reqs in self.pending_req.values():
                total_pending += len(reqs)
        REQS_PENDING.labels(**self._labels).set(total_pending)
        # Update SEGS_TO_ZK and REVS_TO_ZK metrics.
        SEGS_TO_ZK.labels(**self._labels).set(len(self._segs_to_zk))
        REVS_TO_ZK.labels(**self._labels).set(len(self._revs_to_zk))
        # Update HT_ROOT_MAPPTINGS metric.
        HT_ROOT_MAPPTINGS.labels(**self._labels).set(len(self.htroot_if2seg))
        # Update IS_MASTER metric.
        IS_MASTER.labels(**self._labels).set(int(self.zk.have_lock()))

    def run(self):
        """
        Run an instance of the Path Server.

        Starts `self.worker` and `self._check_trc_cert_reqs` as daemon threads
        (each wrapped by thread_safety_net), then enters the base class'
        run().
        """
        background = (
            (self.worker, "PS.worker"),
            (self._check_trc_cert_reqs, "Elem.check_trc_cert_reqs"),
        )
        for target, thread_name in background:
            threading.Thread(target=thread_safety_net,
                             args=(target, ),
                             name=thread_name,
                             daemon=True).start()
        super().run()
Ejemplo n.º 27
0
class SibraServerBase(SCIONElement):
    """
    Base class for the SIBRA service, which is responsible for managing steady
    paths on all interfaces in the local AS.
    """
    SERVICE_TYPE = ServiceType.SIBRA
    PST_TYPE = None

    def __init__(self, server_id, conf_dir, prom_export=None):
        """
        Set up the SIBRA server's local state and join its ZK party.

        :param str server_id: server identifier.
        :param str conf_dir: configuration directory.
        :param str prom_export: prometheus export address.
        """
        super().__init__(server_id, conf_dir, prom_export=prom_export)
        # Outgoing packets, queued by request handling and by SteadyPath
        # objects and sent from the "SB.sender" thread (see sender()).
        self.sendq = Queue()
        # Signing key handed to every SteadyPath created in _steady_add().
        self.signing_key = get_sig_key(self.conf_dir)
        # Segments registered by the local beacon service (see _add_segment()).
        self.segments = PathSegmentDB(max_res_no=1)
        # Maps of {ISD-AS: {steady path id: steady path}} for all incoming
        # (srcs) and outgoing (dests) steady paths:
        # NOTE(review): _process_req() stores None as the value in srcs.
        self.srcs = {}
        self.dests = {}
        # Map of SibraState objects by interface ID
        self.link_states = {}
        # Map of link types by interface ID
        self.link_types = {}
        # Held by worker(), handle_path_reg() and handle_sibra_pkt() while they
        # access the state above.
        self.lock = threading.Lock()
        # Dispatch table: payload class -> payload type -> handler.
        # NOTE(review): the SIBRA entry is keyed by PayloadClass.SIBRA itself
        # rather than a SIBRA-specific type, and handle_sibra_pkt() takes
        # (pkt) while handle_path_reg() takes (cpld, meta); confirm both
        # against the dispatcher.
        self.CTRL_PLD_CLASS_MAP = {
            PayloadClass.PATH: {
                PMT.REG: self.handle_path_reg
            },
            PayloadClass.SIBRA: {
                PayloadClass.SIBRA: self.handle_sibra_pkt
            },
        }
        # Populate link_states/link_types from the topology's border routers;
        # the two dicts above must already exist.
        self._find_links()
        # Packed ZK ID (ISD-AS, server ID, host address and port), then connect
        # to ZK and join the party for this service type, retrying on failure.
        zkid = ZkID.from_values(self.addr.isd_as, self.id,
                                [(self.addr.host, self._port)]).pack()
        self.zk = Zookeeper(self.addr.isd_as, self.SERVICE_TYPE, zkid,
                            self.topology.zookeepers)
        self.zk.retry("Joining party", self.zk.party_setup)

    def _find_links(self):
        for br in self.topology.border_routers:
            for ifid, intf in br.interfaces.items():
                self.link_states[ifid] = SibraState(intf.bandwidth,
                                                    self.addr.isd_as)
                self.link_types[ifid] = intf.link_type

    def run(self):
        """
        Start the worker and sender daemon threads, then run the element.

        Both threads are wrapped by thread_safety_net.
        """
        background = (
            (self.worker, "SB.worker"),
            (self.sender, "SB.sender"),
        )
        for target, thread_name in background:
            threading.Thread(target=thread_safety_net,
                             args=(target, ),
                             name=thread_name,
                             daemon=True).start()
        super().run()

    def worker(self):
        """
        Periodically create and renew steady paths while `run_flag` is set.

        Each cycle calls manage_steady_paths() while holding `self.lock`.
        """
        # Cycle time should be << SIBRA_TICK, as it determines how often
        # reservations are potentially renewed, and the expiration of old
        # reservation blocks.
        cycle = 1.0
        cycle_start = SCIONTime.get_time()
        while self.run_flag.is_set():
            sleep_interval(cycle_start, cycle, "SB.worker cycle")
            cycle_start = SCIONTime.get_time()
            with self.lock:
                self.manage_steady_paths()

    def sender(self):
        """
        Send queued packets through the local socket while `run_flag` is set.

        Packets arrive on `self.sendq` on behalf of request handling and
        SteadyPath objects. Each packet is sent to its first hop with the
        local host as source address. Packets with no resolvable first hop
        are logged and skipped.
        """
        while self.run_flag.is_set():
            pkt = self.sendq.get()
            first_hop, port = self.get_first_hop(pkt)
            if first_hop:
                pkt.addrs.src.host = self.addr.host
                logging.debug("Dst: %s Port: %s\n%s", first_hop, port, pkt)
                self.send(pkt, first_hop, port)
            else:
                logging.error("Unable to determine first hop for packet:\n%s",
                              pkt)

    def handle_path_reg(self, cpld, meta):
        """
        Handle path registration packets from the local beacon service.

        The payload must be a PathRecordsReg. Every registered segment whose
        type equals `self.PST_TYPE` is passed to _add_segment() while holding
        `self.lock`; other segment types are ignored.
        """
        payload = cpld.union.union
        assert isinstance(payload, PathRecordsReg), type(payload)
        meta.close()
        type_name = PST.to_str(self.PST_TYPE)
        with self.lock:
            matching = (seg for seg_type, seg in payload.iter_pcbs()
                        if seg_type == self.PST_TYPE)
            for seg in matching:
                self._add_segment(seg, type_name)

    def _add_segment(self, pcb, name):
        """
        Insert or refresh a segment and pass it to steady paths to its origin.

        :param pcb: the path segment to store in self.segments.
        :param str name: segment type name, used only in log messages.
        """
        outcome = self.segments.update(pcb)
        if outcome == DBResult.ENTRY_ADDED:
            logging.info("%s Segment added: %s", name, pcb.short_desc())
        elif outcome == DBResult.ENTRY_UPDATED:
            logging.debug("%s Segment updated: %s", name, pcb.short_desc())
        origin = pcb.first_ia()
        known = self.dests
        if origin not in known:
            # First segment seen from this ISD-AS: track it as a destination.
            logging.debug("Found new destination ISD-AS: %s", origin)
            known[origin] = {}
        for steady in known[origin].values():
            steady.update_seg(pcb)

    def handle_sibra_pkt(self, pkt):
        """
        Validate an incoming SIBRA packet and dispatch it.

        Only packets carrying a steady SIBRA extension with a request block
        are accepted. Under `self.lock`, forward packets go to _process_req()
        and all others to _process_reply(). Rejected packets are logged.
        """
        ext = find_ext_hdr(pkt, ExtensionClass.HOP_BY_HOP,
                           SibraExtSteady.EXT_TYPE)
        if not ext:
            logging.error("Packet contains no SIBRA extension header")
            return
        if not ext.steady:
            complaint = "Received non-steady SIBRA packet:\n%s"
        elif not ext.req_block:
            complaint = "Received non-request SIBRA packet:\n%s"
        else:
            complaint = None
        if complaint is not None:
            logging.error(complaint, pkt)
            return
        with self.lock:
            process = self._process_req if ext.fwd else self._process_reply
            process(pkt, ext)

    def _process_req(self, pkt, ext):
        """
        Process a steady path request.

        handle_sibra_pkt() calls this, under self.lock, for forward-direction
        SIBRA packets. The request is checked against the known incoming
        steady paths from ext.src_ia and against the state of the interface it
        arrived on. Invalid requests are logged and dropped. Requests that were
        already rejected, and requests accepted here, are reversed and queued
        on self.sendq to return to the requestor.

        :param pkt: the SIBRA request packet.
        :param ext: the packet's SIBRA extension header (steady, with a
            request block).
        """
        path_id = ext.path_ids[0]
        # Make sure a per-source map of known path IDs exists.
        # NOTE(review): this also creates an empty entry for requests that are
        # rejected below.
        self.srcs.setdefault(ext.src_ia, {})
        # A setup must introduce a new path ID; a renewal must refer to an
        # existing one.
        if ext.setup and path_id in self.srcs[ext.src_ia]:
            logging.error("Setup request for existing path id: %s\n%s",
                          hex_str(path_id), pkt)
            return
        elif not ext.setup and path_id not in self.srcs[ext.src_ia]:
            logging.error("Renewal request for non-existant path id: %s\n%s",
                          hex_str(path_id), pkt)
            return
        # Interface the request came in on, as determined by find_last_ifid().
        ifid = find_last_ifid(pkt, ext)
        if ifid not in self.link_states:
            logging.error("Packet came from unknown interface '%s':\n%s", ifid,
                          pkt)
            return
        if not ext.accepted:
            # Request was already rejected, so just send the packet back.
            pkt.reverse()
            self.sendq.put(pkt)
            return
        state = self.link_states[ifid]
        req_info = ext.req_block.info
        bwsnap = req_info.bw.to_snap()
        # Record the reservation in the link state. A non-None result is
        # treated as a failure and logged as the available bandwidth.
        # NOTE(review): the meaning of the positional True argument is not
        # visible here; confirm against SibraState.add_steady.
        bwhint = state.add_steady(path_id, req_info.index, bwsnap,
                                  req_info.exp_tick, True, ext.setup)
        if bwhint is not None:
            # This shouldn't happen - if the local BR accepted the reservation,
            # then there should be enough bandwidth available for it. This means
            # our state is out of sync.
            logging.critical("Requested: %s Available bandwidth: %s\n%s",
                             bwsnap, bwhint, pkt)
            return
        # Remember the path ID so later renewals for it pass the check above;
        # the stored value itself is not used in the visible code.
        self.srcs[ext.src_ia][path_id] = None
        # All is good, return the packet to the requestor.
        pkt.reverse()
        self.sendq.put(pkt)

    def _process_reply(self, pkt, ext):
        """Process a reply to a steady path request."""
        path_id = ext.path_ids[0]
        dest = pkt.addrs.src.isd_as
        steady = self.dests[dest].get(path_id, None)
        if not steady:
            logging.error("Unknown path ID: %s:\n%s", hex_str(path_id), pkt)
            return
        steady.process_reply(pkt, ext)

    def manage_steady_paths(self):
        """
        Create or renew steady paths to all known destinations.

        A destination with no steady paths gets one set up through
        _steady_add(). This happens only once at least STARTUP_WAIT has passed
        since self._startup. Existing steady paths are renewed, and any whose
        renewal raises SteadyPathErrorNoReservation are removed.
        """
        now = time.time()
        # Iterate over a snapshot: _steady_add() deletes the destination from
        # self.dests when no segment is available, which would otherwise raise
        # "RuntimeError: dictionary changed size during iteration".
        for isd_as, steadies in list(self.dests.items()):
            if not steadies and (now - self._startup >= STARTUP_WAIT):
                self._steady_add(isd_as)
                continue
            # Snapshot again, since failed renewals are deleted in place.
            for id_, steady in list(steadies.items()):
                try:
                    steady.renew()
                except SteadyPathErrorNoReservation:
                    del steadies[id_]

    def _steady_add(self, isd_as):
        """
        Set up a new steady path towards `isd_as`.

        If _pick_seg() finds no segment, the destination is removed from
        `self.dests` instead. Otherwise a SteadyPath is built over the chosen
        segment, using the link state and link type of the ingress interface
        in the segment's last hop field. It is registered under its ID in
        `self.dests[isd_as]` and then set up.
        """
        seg = self._pick_seg(isd_as)
        if not seg:
            del self.dests[isd_as]
            return
        if_id = seg.last_hof().ingress_if
        state = self.link_states[if_id]
        kind = self.link_types[if_id]
        # FIXME(kormat): un-hardcode these bandwidths
        reserve = BWSnapshot(500 * 1024, 500 * 1024)
        path = SteadyPath(self.addr, self._port, self.sendq, self.signing_key,
                          kind, state, seg, reserve)
        self.dests[isd_as][path.id] = path
        logging.debug("Setting up steady path %s -> %s over %s",
                      self.addr.isd_as, isd_as, seg.short_desc())
        path.setup()

    def _pick_seg(self, isd_as):
        """
        Select the segment to use for a steady path.

        :param isd_as: destination ISD-AS, matched against each segment's
            first ISD-AS.
        :return: the first matching segment from self.segments. When there is
            none, a warning is logged (unless self._quiet_startup() is true).
            The visible code has no explicit return on that path, and
            _steady_add() treats a falsy result as "no segment".
        """
        # FIXME(kormat): this needs actual logic
        # For now, we use the shortest path
        segs = self.segments(first_ia=isd_as)
        if segs:
            return segs[0]
        # Presumably suppresses the warning during the startup phase.
        if not self._quiet_startup():
            logging.warning("No segments to %s", isd_as)