Пример #1
0
    def post_dissection(self, r):
        """
        Record the server key share in the TLS session and, when the matching
        private share is known locally, store the shared secret in
        ``tls13_dhe_secret``.

        Two cases are handled: the client private share for the group is
        known (combined with the server public key from this extension), or
        the server private share is known (combined with the recorded client
        public share). Nothing is updated when the session is frozen or the
        server share carries no public key.
        """
        if not self.tls_session.frozen and self.server_share.pubkey:
            # if there is a pubkey, we assume the crypto library is ok
            pubshare = self.tls_session.tls13_server_pubshare
            if pubshare:
                pkt_info = r.firstlayer().summary()
                log_runtime.info("TLS: overwriting previous server key share [%s]", pkt_info)  # noqa: E501
            group_name = _tls_named_groups[self.server_share.group]
            pubshare[group_name] = self.server_share.pubkey

            def _exchange(privkey, pubkey):
                # Shared secret for group_name, or None if the group is neither
                # a known FFDH group nor a known curve (this used to leave
                # 'pms' unbound and crash with UnboundLocalError).
                if group_name in six.itervalues(_tls_named_ffdh_groups):
                    return privkey.exchange(pubkey)
                if group_name in six.itervalues(_tls_named_curves):
                    if group_name in ["x25519", "x448"]:
                        return privkey.exchange(pubkey)
                    return privkey.exchange(ec.ECDH(), pubkey)
                return None

            keypair = None  # (privkey, pubkey) to combine, if any
            if group_name in self.tls_session.tls13_client_privshares:
                pubkey = self.server_share.pubkey
                privkey = self.tls_session.tls13_client_privshares[group_name]
                keypair = (privkey, pubkey)
            elif group_name in self.tls_session.tls13_server_privshare:
                pubkey = self.tls_session.tls13_client_pubshares[group_name]
                privkey = self.tls_session.tls13_server_privshare[group_name]
                keypair = (privkey, pubkey)

            if keypair is not None:
                pms = _exchange(*keypair)
                if pms is None:
                    log_runtime.warning("TLS: cannot compute a shared secret for group %s", group_name)  # noqa: E501
                else:
                    self.tls_session.tls13_dhe_secret = pms
        return super(TLS_Ext_KeyShare_SH, self).post_dissection(r)
Пример #2
0
    def build_graph(self):
        # type: () -> str
        """Return the Graphviz (dot) source describing this automaton.

        States become nodes: initial states are blue boxes listed first,
        final states green octagons, error states red octagons and stop
        states orange boxes. An edge is drawn whenever a function's code
        mentions another state name: green for state functions, purple for
        conditions, red for receive conditions, orange for I/O events and
        blue for timeouts. Transition labels list the condition name
        followed by its actions.
        """
        def action_suffix(condname):
            # One "\l>[action]" entry per action attached to the condition
            return "".join("\\l>[%s]" % act.__name__
                           for act in self.actions[condname])

        initial_nodes = []
        other_nodes = []
        for state in six.itervalues(self.states):
            if state.atmt_initial:
                initial_nodes.append(
                    '\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n'  # noqa: E501
                    % state.atmt_state)
            elif state.atmt_final:
                other_nodes.append(
                    '\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n'  # noqa: E501
                    % state.atmt_state)
            elif state.atmt_error:
                other_nodes.append(
                    '\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n'  # noqa: E501
                    % state.atmt_state)
            elif state.atmt_stop:
                other_nodes.append(
                    '\t"%s" [ style=filled, fillcolor=orange, shape=box, root=true ];\n'  # noqa: E501
                    % state.atmt_state)

        chunks = ['digraph "%s" {\n' % self.__class__.__name__]
        # Initial nodes go first for better rendering, in reverse discovery
        # order, followed by every other styled node in discovery order.
        chunks.extend(reversed(initial_nodes))
        chunks.extend(other_nodes)

        for state in six.itervalues(self.states):
            code = state.atmt_origfunc.__code__
            for ref in code.co_names + code.co_consts:
                if ref in self.states:
                    chunks.append('\t"%s" -> "%s" [ color=green ];\n'
                                  % (state.atmt_state, ref))

        transition_tables = (
            ("purple", list(self.conditions.items())),
            ("red", list(self.recv_conditions.items())),
            ("orange", list(self.ioevents.items())),
        )
        for colour, entries in transition_tables:
            for source, funcs in entries:
                for func in funcs:
                    code = func.__code__
                    for ref in code.co_names + code.co_consts:
                        if ref not in self.states:
                            continue
                        label = func.atmt_condname + action_suffix(func.atmt_condname)  # noqa: E501
                        chunks.append(
                            '\t"%s" -> "%s" [label="%s", color=%s];\n'
                            % (source, ref, label, colour))

        for source, timers in six.iteritems(self.timeout):
            for delay, func in timers:
                if func is None:
                    continue
                code = func.__code__
                for ref in code.co_names + code.co_consts:
                    if ref not in self.states:
                        continue
                    label = ("%s/%.1fs" % (func.atmt_condname, delay)
                             + action_suffix(func.atmt_condname))
                    chunks.append('\t"%s" -> "%s" [label="%s",color=blue];\n'
                                  % (source, ref, label))

        chunks.append("}\n")
        return "".join(chunks)
Пример #3
0
 def getfield(self, pkt, s):
     """Split the MAC off the front of ``s``.

     Returns ``(remaining, mac)``, where ``mac`` is the first ``mac_len``
     bytes passed through ``m2i``. For a non-AEAD cipher whose key material
     is not entirely ready, ``s`` is returned untouched with an empty MAC.
     """
     read_state = pkt.tls_session.rcs
     if read_state.cipher.type != "aead":
         if False in six.itervalues(read_state.cipher.ready):
             # XXX Find a more proper way to handle the still-encrypted case
             return s, b""
     mac_len = read_state.mac_len
     return s[mac_len:], self.m2i(pkt, s[:mac_len])
Пример #4
0
    def auth_decrypt(self, A, C, seq_num):
        """
        Decrypt ``C`` (ciphertext followed by the tag) and check the tag over
        the additional data ``A``; returns ``(plaintext, tag)``.

        An AEADTagError is raised when authentication fails; catching it is
        left to the caller. A CipherError carrying the encrypted input is
        raised when some key material is not ready.
        """
        ciphertext, tag = C[:-self.tag_len], C[-self.tag_len:]
        if False in six.itervalues(self.ready):
            raise CipherError(ciphertext, tag)

        if hasattr(self, "pc_cls"):
            self._cipher.mode._initialization_vector = self._get_nonce(seq_num)
            self._cipher.mode._tag = tag
            decryptor = self._cipher.decryptor()
            decryptor.authenticate_additional_data(A)
            plaintext = decryptor.update(ciphertext)
            try:
                decryptor.finalize()
            except InvalidTag:
                raise AEADTagError(plaintext, tag)
            return plaintext, tag

        try:
            is_ccm = (conf.crypto_valid_advanced and
                      isinstance(self._cipher, AESCCM))
            if not is_ccm and (conf.crypto_valid_advanced and
                               isinstance(self, Cipher_CHACHA20_POLY1305)):
                # ChaCha20-Poly1305 also authenticates the ciphertext length
                A += struct.pack("!H", len(ciphertext))
            plaintext = self._cipher.decrypt(self._get_nonce(seq_num), ciphertext + tag, A)  # noqa: E501
        except InvalidTag:
            raise AEADTagError("<unauthenticated data>", tag)
        return plaintext, tag
Пример #5
0
    def auth_encrypt(self, P, A, seq_num=None):
        """
        Encrypt ``P``, authenticating ``A`` as additional data, and return the
        explicit nonce part followed by the ciphertext and its tag.

        The explicit nonce is advanced after each call. ``seq_num`` is unused:
        it only exists so that ChaCha20Poly1305, which is a _AEADCipher_TLS13
        driven by the TLS 1.2 logic in record.py, shares this signature.
        Raises CipherError(P, A) when some key material is not ready.
        """
        if False in six.itervalues(self.ready):
            raise CipherError(P, A)

        if not hasattr(self, "pc_cls"):
            # Newer crypto API: the tag comes appended to the ciphertext
            sealed = self._cipher.encrypt(self._get_nonce(), P, A)
        else:
            self._cipher.mode._initialization_vector = self._get_nonce()
            self._cipher.mode._tag = None
            encryptor = self._cipher.encryptor()
            encryptor.authenticate_additional_data(A)
            sealed = encryptor.update(P) + encryptor.finalize()
            sealed += encryptor.tag

        explicit_part = pkcs_i2osp(self.nonce_explicit,
                                   self.nonce_explicit_len)
        self._update_nonce_explicit()
        return explicit_part + sealed
Пример #6
0
    def __new__(cls,
                name,  # type: str
                bases,  # type: Tuple[type, ...]
                dct  # type: Dict[str, Any]
                ):
        # type: (...) -> Type[ASN1_Class]
        """Build an ASN1 class whose integer attributes become ASN1Tag objects.

        Tags inherited from the bases are cloned unless redefined, every tag
        is indexed in ``__rdict__`` (tag -> tag), and each tag's ``context``
        is pointed at the newly created class.
        """
        for base in bases:
            for attr, value in six.iteritems(base.__dict__):
                if isinstance(value, ASN1Tag) and attr not in dct:
                    dct[attr] = value.clone()

        reverse_map = {}
        for attr, value in list(six.iteritems(dct)):
            if isinstance(value, int):
                value = ASN1Tag(attr, value)
                dct[attr] = value
            elif not isinstance(value, ASN1Tag):
                continue
            reverse_map[value] = value
        dct["__rdict__"] = reverse_map

        new_cls = cast('Type[ASN1_Class]',
                       type.__new__(cls, name, bases, dct))
        for value in six.itervalues(new_cls.__dict__):
            if isinstance(value, ASN1Tag):
                # overwrite ASN1Tag contexts, even cloned ones
                value.context = new_cls
        return new_cls
Пример #7
0
 def dev_from_networkname(self, network_name):
     # type: (str) -> NetworkInterface
     """Return the interface whose ``network_name`` equals *network_name*.

     :raises ValueError: if no interface matches.
     """
     try:
         return next(iface for iface in six.itervalues(self)  # type: ignore
                     if iface.network_name == network_name)
     except (StopIteration, RuntimeError):
         # StopIteration: nothing matched. RuntimeError: presumably the
         # mapping changed size during iteration.
         raise ValueError("Unknown network interface %r" % network_name)
Пример #8
0
def get_ips(v6=False):
    """Map every interface in ``conf.ifaces`` to its IP addresses, as known by
    the Windows system.

    The ``ips[6]`` entry is used when *v6* is true, ``ips[4]`` otherwise.
    Should only be used as a WinPcapy fallback.
    """
    family = 6 if v6 else 4
    return {iface: iface.ips[family] for iface in six.itervalues(conf.ifaces)}
Пример #9
0
 def unfilter(self):
     # type: () -> None
     """Re-enable dissection for all layers.

     Puts back the ``payload_guess`` list saved for every filtered class,
     then clears the backup.

     :raises ValueError: if no filter is currently active.
     """
     if not self.filtered:
         raise ValueError("Not filtered. Please filter first")
     # Restore straight from the backup instead of walking self.ldict: a
     # layer registered after filtering has no backup entry, and looking it
     # up raised KeyError and left the state half-restored.
     for cls, saved_guess in self._backup_dict.items():
         cls.payload_guess = saved_guess
     self._backup_dict.clear()
     self.filtered = False
Пример #10
0
 def __init__(self, size=None, rndstr=None):
     """Set up the random generator.

     :param size: defaults to ``RandNumExpo(0.05)`` when None
     :param rndstr: defaults to ``RandBin(RandNum(0, 255))`` when None

     ``_opts`` holds every ``DHCPOptions`` value except "pad" and "end".
     """
     self.size = RandNumExpo(0.05) if size is None else size
     self.rndstr = RandBin(RandNum(0, 255)) if rndstr is None else rndstr
     self._opts = list(six.itervalues(DHCPOptions))
     for marker in ("pad", "end"):
         self._opts.remove(marker)
Пример #11
0
 def dev_from_name(self, name):
     # type: (str) -> NetworkInterface
     """Return the first interface whose ``name`` or ``description``
     equals *name*.

     :raises ValueError: if no interface matches.
     """
     try:
         for iface in six.itervalues(self):  # type: ignore
             if iface.name == name or iface.description == name:
                 return iface
     except (StopIteration, RuntimeError):
         pass
     raise ValueError("Unknown network interface %r" % name)
Пример #12
0
 def encrypt(self, data):
     """
     Encrypt ``data`` and return the ciphertext. The last block of the
     ciphertext becomes ``self.iv``, as SSLv3 and TLS 1.0 require; for
     TLS 1.1/1.2 the IV is overwritten in TLS.post_build().
     Raises CipherError(data) when some key material is not ready.
     """
     if False in six.itervalues(self.ready):
         raise CipherError(data)
     enc = self._cipher.encryptor()
     ciphertext = enc.update(data)
     ciphertext += enc.finalize()
     self.iv = ciphertext[-self.block_size:]
     return ciphertext
Пример #13
0
 def dev_from_index(self, if_index):
     # type: (int) -> NetworkInterface
     """Return the interface whose ``index`` equals *if_index*.

     *if_index* is coerced with ``int()`` for backward compatibility. If
     nothing matches and the index is 1, the interface named
     ``conf.loopback_name`` is returned instead.

     :raises ValueError: if no interface matches.
     """
     try:
         if_index = int(if_index)  # Backward compatibility
         candidates = (iface for iface in six.itervalues(self)  # type: ignore
                       if iface.index == if_index)
         return next(candidates)
     except (StopIteration, RuntimeError):
         pass
     if str(if_index) != "1":
         raise ValueError("Unknown network interface index %r" % if_index)
     # Test if the loopback interface is set up
     return self.dev_from_networkname(conf.loopback_name)
Пример #14
0
 def decrypt(self, data):
     """
     Decrypt ``data`` and return the plaintext. The last block of the
     *ciphertext* input becomes ``self.iv``, as SSLv3 and TLS 1.0 require;
     for TLS 1.1/1.2 the IV is overwritten in TLS.pre_dissect().
     Raises CipherError(data) when some key material is not ready.
     """
     if False in six.itervalues(self.ready):
         raise CipherError(data)
     dec = self._cipher.decryptor()
     plaintext = dec.update(data)
     plaintext += dec.finalize()
     self.iv = data[-self.block_size:]
     return plaintext
Пример #15
0
    def getfield(self, pkt, s):
        """
        If the decryption of the content did not fail with a CipherError,
        we begin a loop on the clear content in order to get as much messages
        as possible, of the type advertised in the record header. This is
        notably important for several TLS handshake implementations, which
        may for instance pack a server_hello, a certificate, a
        server_key_exchange and a server_hello_done, all in one record.
        Each parsed message may update the TLS context through their method
        .post_dissection_tls_session_update().

        If the decryption failed with a CipherError, presumably because we
        missed the session keys, we signal it by returning a
        _TLSEncryptedContent packet which simply contains the ciphered data.

        Returns (bytes left after this field, dissected value). The value is
        a list of messages, except on the missing-keys path, where it is a
        single _TLSEncryptedContent packet.
        """
        tmp_len = self.length_from(pkt)
        lst = []
        ret = b""
        remain = s
        if tmp_len is not None:
            # Only the first tmp_len bytes belong to this field; the rest is
            # handed back to the caller untouched.
            remain, ret = s[:tmp_len], s[tmp_len:]

        if remain == b"":
            # Empty content: return one empty message matching the record
            # type (23 -> TLSApplicationData, only when the session version,
            # 0x0303 if unset, is above 0x0200; 20 -> TLSChangeCipherSpec;
            # anything else -> empty Raw).
            if (((pkt.tls_session.tls_version or 0x0303) > 0x0200)
                    and hasattr(pkt, "type") and pkt.type == 23):
                return ret, [TLSApplicationData(data=b"")]
            elif hasattr(pkt, "type") and pkt.type == 20:
                return ret, [TLSChangeCipherSpec()]
            else:
                return ret, [Raw(load=b"")]

        if False in six.itervalues(pkt.tls_session.rcs.cipher.ready):
            # Key material missing: keep the data as an opaque blob.
            # NOTE(review): this returns a single packet rather than a list
            # like the other paths; confirm callers rely on that.
            return ret, _TLSEncryptedContent(remain)
        else:
            while remain:
                raw_msg = remain
                p = self.m2i(pkt, remain)
                if Padding in p:
                    # Bytes m2i did not consume end up in a Padding layer;
                    # they are the start of the next message. Take them,
                    # detach the Padding from p, and trim raw_msg down to
                    # the bytes of this message only.
                    pad = p[Padding]
                    remain = pad.load
                    del pad.underlayer.payload
                    if len(remain) != 0:
                        raw_msg = raw_msg[:-len(remain)]
                else:
                    remain = b""

                if isinstance(p, _GenericTLSSessionInheritance):
                    if not p.tls_session.frozen:
                        # Let the message update the TLS session context
                        p.post_dissection_tls_session_update(raw_msg)

                lst.append(p)
            # remain is always empty here (the loop ran until it was falsy)
            return remain + ret, lst
Пример #16
0
 def filter(self, items):
     # type: (List[Type[Packet]]) -> None
     """Disable dissection of unused layers to speed up dissection.

     For every registered layer class, a copy of its ``payload_guess`` is
     saved (only the first time the class is seen). The list is then
     replaced by the entries whose class is in *items*.

     :raises ValueError: if a filter is already active.
     """
     if self.filtered:
         raise ValueError("Already filtered. Please disable it first")
     for layer_classes in six.itervalues(self.ldict):
         for layer_cls in layer_classes:
             if layer_cls in self._backup_dict:
                 continue
             original_guess = layer_cls.payload_guess
             self._backup_dict[layer_cls] = original_guess[:]
             layer_cls.payload_guess = [
                 guess for guess in original_guess if guess[1] in items
             ]
     self.filtered = True
Пример #17
0
 def __init__(self, objlist=None):
     # type: (Optional[List[Type[ASN1_Object[Any]]]]) -> None
     """Set up the candidate ASN.1 object classes and the character set.

     :param objlist: classes to choose from; when None or empty, every
         ``_asn1_obj`` found in ``ASN1_Class_UNIVERSAL.__rdict__`` is used.
     """
     if not objlist:
         universal_tags = six.itervalues(
             ASN1_Class_UNIVERSAL.__rdict__  # type: ignore
         )
         objlist = [tag._asn1_obj for tag in universal_tags
                    if hasattr(tag, "_asn1_obj")]
     self.objlist = objlist
     self.chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"  # noqa: E501
Пример #18
0
 def randval(self):
     # type: () -> RandChoice
     """Return a RandChoice over one random value per possible choice.

     Choices with ``ASN1_root`` are fuzzed instances. Choices with
     ``ASN1_tag`` use ``randval()``, on a dummy instance when the choice is a
     class. Anything else is ignored.
     """
     candidates = []
     for choice in six.itervalues(self.choices):
         if hasattr(choice, "ASN1_root"):
             # should be ASN1_Packet class
             candidates.append(packet.fuzz(choice()))
             continue
         if not hasattr(choice, "ASN1_tag"):
             continue
         if isinstance(choice, type):
             # should be (basic) ASN1F_field class
             candidates.append(choice("dummy", None).randval())
         else:
             # should be ASN1F_PACKET instance
             candidates.append(choice.randval())
     return RandChoice(*candidates)
Пример #19
0
    def auth_decrypt(self, A, C, seq_num=None, add_length=True):
        """
        Split ``C`` into explicit nonce, ciphertext and tag, then decrypt and
        authenticate it together with the additional data ``A``.

        Returns ``(nonce_explicit_str, plaintext, tag)``. An AEADTagError is
        raised when verification fails; catching it is left to the caller.
        A CipherError carrying the three input parts is raised when some key
        material is not ready.

        When ``add_length`` is true (always the case for TLS; the switch
        exists mostly for test cases), the ciphertext length is appended to
        ``A``. The specification calls for the TLSCompressed length, but
        both lengths are equal.

        ``seq_num`` is unused: it only exists because ChaCha20Poly1305, a
        _AEADCipher_TLS13 driven by the TLS 1.2 logic in record.py, shares
        this signature.
        """
        explicit_len = self.nonce_explicit_len
        nonce_explicit_str = C[:explicit_len]
        ciphertext = C[explicit_len:-self.tag_len]
        tag = C[-self.tag_len:]

        if False in six.itervalues(self.ready):
            raise CipherError(nonce_explicit_str, ciphertext, tag)

        self.nonce_explicit = pkcs_os2ip(nonce_explicit_str)
        if add_length:
            A += struct.pack("!H", len(ciphertext))

        if not hasattr(self, "pc_cls"):
            try:
                plaintext = self._cipher.decrypt(self._get_nonce(),
                                                 ciphertext + tag, A)
            except InvalidTag:
                raise AEADTagError(nonce_explicit_str,
                                   "<unauthenticated data>",
                                   tag)
            return nonce_explicit_str, plaintext, tag

        self._cipher.mode._initialization_vector = self._get_nonce()
        self._cipher.mode._tag = tag
        decryptor = self._cipher.decryptor()
        decryptor.authenticate_additional_data(A)
        plaintext = decryptor.update(ciphertext)
        try:
            decryptor.finalize()
        except InvalidTag:
            raise AEADTagError(nonce_explicit_str, plaintext, tag)
        return nonce_explicit_str, plaintext, tag
Пример #20
0
def _dissect_headers(obj, s):
    """Parse the raw HTTP message *s* into the scapy layer *obj* (an
    HTTPRequest or HTTPResponse).

    Each field of *obj* is looked up by its stripped, lower-cased name.
    Headers that match no field are stored in the 'Unknown_Headers' field,
    presumably keyed by their original header names.

    Returns ``(first_line, body)``.
    """
    first_line, headers, body = _parse_headers_and_body(s)
    missing = object()
    for field in obj.fields_desc:
        # We want to still parse wrongly capitalized fields
        key = _strip_header_name(field.name).lower()
        entry = headers.pop(key, missing)
        if entry is missing:
            continue
        _, value = entry
        obj.setfieldval(field.name, value)
    if headers:
        obj.setfieldval('Unknown_Headers', dict(six.itervalues(headers)))
    return first_line, body
Пример #21
0
    def auth_encrypt(self, P, A, seq_num):
        """
        Encrypt the data, and append the computed authentication code.
        The additional data for TLS 1.3 is the record header.

        Note that the cipher's authentication tag must be None when encrypting.

        :param P: plaintext
        :param A: additional authenticated data
        :param seq_num: sequence number the nonce is derived from
        :returns: ciphertext followed by the authentication tag
        :raises CipherError: (P, A) when some key material is not ready
        """
        if False in six.itervalues(self.ready):
            raise CipherError(P, A)

        if hasattr(self, "pc_cls"):
            self._cipher.mode._tag = None
            self._cipher.mode._initialization_vector = self._get_nonce(seq_num)
            encryptor = self._cipher.encryptor()
            encryptor.authenticate_additional_data(A)
            res = encryptor.update(P) + encryptor.finalize()
            res += encryptor.tag
        else:
            # The former AESCCM special case made exactly this same call, so
            # one AEAD encrypt call covers every cipher here.
            res = self._cipher.encrypt(self._get_nonce(seq_num), P, A)
        return res
Пример #22
0
# DNS RR type code -> dedicated scapy record class.
DNSRR_DISPATCHER = {
    6: DNSRRSOA,  # RFC 1035
    15: DNSRRMX,  # RFC 1035
    33: DNSRRSRV,  # RFC 2782
    41: DNSRROPT,  # RFC 6891 (EDNS(0); originally RFC 2671)
    43: DNSRRDS,  # RFC 4034
    46: DNSRRRSIG,  # RFC 4034
    47: DNSRRNSEC,  # RFC 4034
    48: DNSRRDNSKEY,  # RFC 4034
    50: DNSRRNSEC3,  # RFC 5155
    51: DNSRRNSEC3PARAM,  # RFC 5155
    250: DNSRRTSIG,  # RFC 2845
    32769: DNSRRDLV,  # RFC 4431
}

# Every class in DNSRR_DISPATCHER. Despite the name, this also includes
# record types that are not DNSSEC-specific (SOA, MX, SRV, OPT, TSIG).
DNSSEC_CLASSES = tuple(six.itervalues(DNSRR_DISPATCHER))


def isdnssecRR(obj):
    """Return True when *obj* is an instance of any class in DNSSEC_CLASSES.

    DNSSEC_CLASSES is built from DNSRR_DISPATCHER, so record types that are
    not DNSSEC-specific (e.g. SOA, MX) match as well.
    """
    dedicated_types = DNSSEC_CLASSES
    return isinstance(obj, dedicated_types)


class DNSRR(InheritOriginDNSStrPacket):
    name = "DNS Resource Record"
    show_indent = 0
    fields_desc = [
        DNSStrField("rrname", ""),
        ShortEnumField("type", 1, dnstypes),
        ShortEnumField("rclass", 1, dnsclasses),
        IntField("ttl", 0),
        FieldLenField("rdlen", None, length_of="rdata", fmt="H"),
Пример #23
0
 def itervalues(self):
     # type: () -> Iterator[_V]
     """Return an iterator over the values of the wrapped mapping ``self.d``."""
     wrapped = self.d
     return six.itervalues(wrapped)  # type: ignore
Пример #24
0
 def decrypt(self, data):
     """Feed *data* to ``self.decryptor`` and return what it produces.

     The input is also appended to ``self._dec_updated_with``. Raises
     CipherError(data) when some key material is not ready.
     """
     if False in six.itervalues(self.ready):
         raise CipherError(data)
     self._dec_updated_with += data
     decryptor = self.decryptor
     return decryptor.update(data)
Пример #25
0
    fields_desc = [ByteEnumField("version", 0x01, ofp_version),
                   ByteEnumField("type", 4, ofp_type),
                   ShortField("len", None),
                   IntField("xid", 0),
                   IntField("vendor", 0)]


class OFPTFeaturesRequest(_ofp_header):
    """OFPT_FEATURES_REQUEST (type 5): only the common header fields."""
    name = "OFPT_FEATURES_REQUEST"
    fields_desc = [
        ByteEnumField("version", 0x01, ofp_version),
        ByteEnumField("type", 5, ofp_type),
        ShortField("len", None),
        IntField("xid", 0),
    ]


# Every action type name except 'OFPAT_VENDOR', in ofp_action_types order.
ofp_action_types_flags = list(filter(lambda name: name != 'OFPAT_VENDOR',
                                     six.itervalues(ofp_action_types)))


class OFPTFeaturesReply(_ofp_header):
    name = "OFPT_FEATURES_REPLY"
    fields_desc = [ByteEnumField("version", 0x01, ofp_version),
                   ByteEnumField("type", 6, ofp_type),
                   ShortField("len", None),
                   IntField("xid", 0),
                   LongField("datapath_id", 0),
                   IntField("n_buffers", 0),
                   ByteField("n_tables", 1),
                   X3BytesField("pad", 0),
                   FlagsField("capabilities", 0, 32, ["FLOW_STATS",
                                                      "TABLE_STATS",
Пример #26
0
 def values(self):
     # type: () -> Any
     """Return the cached values as a list.

     With no timeout, every value stored in ``self.__dict__`` is returned.
     Otherwise only values whose ``_timetable`` timestamp is less than
     ``timeout`` seconds old are returned.
     """
     if self.timeout is None:
         # Read self.__dict__ directly, like the timeout branch below: on
         # Python 3, six.itervalues(self) is iter(self.values()), i.e. this
         # very method, which recursed forever.
         return list(self.__dict__.values())
     t0 = time.time()
     return [v for (k, v) in self.__dict__.items() if t0 - self._timetable[k] < self.timeout]  # noqa: E501
Пример #27
0
 def itervalues(self):
     # type: () -> Iterator[Any]
     """Return an iterator over the cached values (not key/value pairs).

     With no timeout, every value stored in ``self.__dict__`` is yielded.
     Otherwise only values whose ``_timetable`` timestamp is less than
     ``timeout`` seconds old are yielded; the reference time is taken once,
     when this method is called.
     """
     if self.timeout is None:
         return iter(self.__dict__.values())
     t0 = time.time()
     return (v for (k, v) in self.__dict__.items() if t0 - self._timetable[k] < self.timeout)  # noqa: E501
Пример #28
0
    def afterglow(self,
                  src=None,  # type: Optional[Callable[[_Inner], Any]]
                  event=None,  # type: Optional[Callable[[_Inner], Any]]
                  dst=None,  # type: Optional[Callable[[_Inner], Any]]
                  **kargs  # type: Any
                  ):
        # type: (...) -> Any
        """Experimental clone attempt of http://sourceforge.net/projects/afterglow
        each datum is reduced as src -> event -> dst and the data are graphed.
        by default we have IP.src -> IP.dport -> IP.dst

        :param src: callable extracting the source node from a datum
        :param event: callable extracting the event node from a datum
        :param dst: callable extracting the destination node from a datum
        :param kargs: forwarded to do_graph()
        :returns: whatever do_graph() returns

        A datum for which any extractor raises is skipped. Node sizes grow
        with the number of times each node was seen.
        """
        if src is None:
            src = lambda *x: x[0]['IP'].src
        if event is None:
            event = lambda *x: x[0]['IP'].dport
        if dst is None:
            dst = lambda *x: x[0]['IP'].dst
        sl = {}  # type: Dict[Any, Tuple[Union[float, int], List[Any]]]
        el = {}  # type: Dict[Any, Tuple[Union[float, int], List[Any]]]
        dl = {}  # type: Dict[Any, int]
        for i in self.res:
            try:
                s, e, d = src(i), event(i), dst(i)
                if s in sl:
                    n, lst = sl[s]
                    n += 1
                    if e not in lst:
                        lst.append(e)
                    sl[s] = (n, lst)
                else:
                    sl[s] = (1, [e])
                if e in el:
                    n, lst = el[e]
                    n += 1
                    if d not in lst:
                        lst.append(d)
                    el[e] = (n, lst)
                else:
                    el[e] = (1, [d])
                dl[d] = dl.get(d, 0) + 1
            except Exception:
                # Best effort: skip data the extractors cannot handle
                continue

        def minmax(x):
            # type: (Any) -> Tuple[int, int]
            # (low, high) bounds for scaling node sizes. low becomes 0 when
            # all values are equal and high becomes 1 when it is 0, so that
            # high - low is never zero. An empty input (no usable datum)
            # yields (0, 1) instead of reduce() raising TypeError.
            values = list(x)
            if not values:
                return 0, 1
            m, M = min(values), max(values)
            if m == M:
                m = 0
            if M == 0:
                M = 1
            return m, M

        mins, maxs = minmax(x for x, _ in six.itervalues(sl))
        mine, maxe = minmax(x for x, _ in six.itervalues(el))
        mind, maxd = minmax(six.itervalues(dl))

        gr = 'digraph "afterglow" {\n\tedge [len=2.5];\n'

        gr += "# src nodes\n"
        for s in sl:
            n, _ = sl[s]
            n = 1 + float(n - mins) / (maxs - mins)
            gr += '"src.%s" [label = "%s", shape=box, fillcolor="#FF0000", style=filled, fixedsize=1, height=%.2f,width=%.2f];\n' % (repr(s), repr(s), n, n)  # noqa: E501
        gr += "# event nodes\n"
        for e in el:
            n, _ = el[e]
            n = 1 + float(n - mine) / (maxe - mine)
            gr += '"evt.%s" [label = "%s", shape=circle, fillcolor="#00FFFF", style=filled, fixedsize=1, height=%.2f, width=%.2f];\n' % (repr(e), repr(e), n, n)  # noqa: E501
        for d in dl:
            n = dl[d]
            n = 1 + float(n - mind) / (maxd - mind)
            gr += '"dst.%s" [label = "%s", shape=triangle, fillcolor="#0000ff", style=filled, fixedsize=1, height=%.2f, width=%.2f];\n' % (repr(d), repr(d), n, n)  # noqa: E501

        gr += "###\n"
        for s in sl:
            n, lst1 = sl[s]
            for e in lst1:
                gr += ' "src.%s" -> "evt.%s";\n' % (repr(s), repr(e))
        for e in el:
            n, lst2 = el[e]
            for d in lst2:
                gr += ' "evt.%s" -> "dst.%s";\n' % (repr(e), repr(d))

        gr += "}"
        return do_graph(gr, **kargs)
Пример #29
0
    def __new__(cls, name, bases, dct):
        # type: (str, Tuple[Any], Dict[str, Any]) -> Type[Automaton]
        """Create an Automaton class and index its ATMT-decorated members.

        Builds the per-state tables (states, conditions, receive conditions,
        I/O events, timeouts), the per-condition action lists and the
        initial/stop state lists, sorts them, wires I/O supersockets as class
        attributes and, when available, exposes parse_args' signature as
        the class signature.
        """
        cls = super(Automaton_metaclass, cls).__new__(  # type: ignore
            cls, name, bases, dct)
        cls.states = {}
        cls.recv_conditions = {}  # type: Dict[str, List[_StateWrapper]]
        cls.conditions = {}  # type: Dict[str, List[_StateWrapper]]
        cls.ioevents = {}  # type: Dict[str, List[_StateWrapper]]
        cls.timeout = {
        }  # type: Dict[str, List[Tuple[int, _StateWrapper]]] # noqa: E501
        cls.actions = {}  # type: Dict[str, List[_StateWrapper]]
        cls.initial_states = []  # type: List[_StateWrapper]
        cls.stop_states = []  # type: List[_StateWrapper]
        cls.ionames = []
        cls.iosupersockets = []

        # Breadth-first walk over the class and its bases; the first
        # definition seen for a name wins, so subclasses override bases.
        members = {}
        classes = [cls]
        while classes:
            c = classes.pop(
                0
            )  # order is important to avoid breaking method overloading  # noqa: E501
            classes += list(c.__bases__)
            for k, v in six.iteritems(c.__dict__):
                if k not in members:
                    members[k] = v

        # Only members tagged by the ATMT decorators carry atmt_type
        decorated = [
            v for v in six.itervalues(members) if hasattr(v, "atmt_type")
        ]

        # First pass: register states (with empty per-state lists) and
        # create an empty action list for every condition-like member, so
        # the second pass can append regardless of definition order.
        for m in decorated:
            if m.atmt_type == ATMT.STATE:
                s = m.atmt_state
                cls.states[s] = m
                cls.recv_conditions[s] = []
                cls.ioevents[s] = []
                cls.conditions[s] = []
                cls.timeout[s] = []
                if m.atmt_initial:
                    cls.initial_states.append(m)
                if m.atmt_stop:
                    cls.stop_states.append(m)
            elif m.atmt_type in [
                    ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT, ATMT.IOEVENT
            ]:  # noqa: E501
                cls.actions[m.atmt_condname] = []

        # Second pass: attach each condition/receive/ioevent/timeout to its
        # state, and each action to every condition listed in atmt_cond.
        for m in decorated:
            if m.atmt_type == ATMT.CONDITION:
                cls.conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.RECV:
                cls.recv_conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.IOEVENT:
                cls.ioevents[m.atmt_state].append(m)
                cls.ionames.append(m.atmt_ioname)
                if m.atmt_as_supersocket is not None:
                    cls.iosupersockets.append(m)
            elif m.atmt_type == ATMT.TIMEOUT:
                cls.timeout[m.atmt_state].append((m.atmt_timeout, m))
            elif m.atmt_type == ATMT.ACTION:
                for co in m.atmt_cond:
                    cls.actions[co].append(m)

        # Timeouts sorted by delay, each list ending with a (None, None)
        # entry (presumably an end marker for the run loop — confirm).
        for v in six.itervalues(cls.timeout):
            v.sort(key=lambda x: x[0])
            v.append((None, None))
        for v in itertools.chain(six.itervalues(cls.conditions),
                                 six.itervalues(cls.recv_conditions),
                                 six.itervalues(cls.ioevents)):
            v.sort(key=lambda x: x.atmt_prio)
        # Actions ordered by the value atmt_cond holds for this condition
        for condname, actlst in six.iteritems(cls.actions):
            actlst.sort(key=lambda x: x.atmt_cond[condname])

        for ioev in cls.iosupersockets:
            setattr(
                cls, ioev.atmt_as_supersocket,
                _ATMT_to_supersocket(ioev.atmt_as_supersocket,
                                     ioev.atmt_ioname,
                                     cast(Type["Automaton"], cls)))

        # Inject signature
        # (inspect.signature may be missing, hence AttributeError is ignored)
        try:
            import inspect
            cls.__signature__ = inspect.signature(
                cls.parse_args)  # type: ignore  # noqa: E501
        except (ImportError, AttributeError):
            pass

        return cast(Type["Automaton"], cls)