Example 1
 def test_repr(self):
     """
     Returned validator has a useful `__repr__`.
     """
     v = optional(instance_of(int))
     assert (
         ("<optional validator for <instance_of validator for type "
          "<{type} 'int'>> or None>")
         .format(type=TYPE)
     ) == repr(v)
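A minimal, self-contained sketch (not part of the test suite above) of the behavior that `optional(instance_of(int))` provides: the wrapped validator runs for non-None values, while None is let through unchanged.

import attr
from attr.validators import instance_of, optional

@attr.s
class C:  # illustrative class, not from the tests above
    x = attr.ib(default=None, validator=optional(instance_of(int)))

C()          # None passes: optional() short-circuits before instance_of runs
C(x=42)      # an int satisfies the wrapped instance_of(int)
try:
    C(x="42")
except TypeError as exc:
    # instance_of() raises TypeError; its args also carry the attribute,
    # the expected type, and the offending value (see the next example)
    print(exc.args[0])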
Example 2
    def test_fail(self, validator):
        """
        Raises `TypeError` on wrong types.
        """
        v = optional(validator)
        a = simple_attr("test")
        with pytest.raises(TypeError) as e:
            v(None, a, "42")
        assert (
            "'test' must be <{type} 'int'> (got '42' that is a <{type} "
            "'str'>).".format(type=TYPE),
            a, int, "42",

        ) == e.value.args
Example 3
    def test_repr(self, validator):
        """
        Returned validator has a useful `__repr__`.
        """
        v = optional(validator)

        if isinstance(validator, list):
            assert (
                ("<optional validator for _AndValidator(_validators=[{func}, "
                 "<instance_of validator for type <{type} 'int'>>]) or None>")
                .format(func=repr(always_pass), type=TYPE)
            ) == repr(v)
        else:
            assert (
                ("<optional validator for <instance_of validator for type "
                 "<{type} 'int'>> or None>")
                .format(type=TYPE)
            ) == repr(v)
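A small sketch of the list case exercised by the test above: `optional()` also accepts a list of validators, which attrs combines into a single and-validator before wrapping (hence the `_AndValidator` in the expected repr). `always_pass` here is a stand-in for the helper used by the test suite.

import attr
from attr.validators import instance_of, optional

def always_pass(instance, attribute, value):
    """A validator that accepts anything (stand-in for the test helper)."""

@attr.s
class C:
    x = attr.ib(default=None, validator=optional([always_pass, instance_of(int)]))

C(x=42)    # both validators in the and-combination run
C(x=None)  # neither runs: None short-circuits the optional() wrapper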
Example 4
from collections import defaultdict
from io import BytesIO
from itertools import chain
from pathlib import Path
import sys
from typing import Dict, Iterable, Iterator, List, Union, Tuple, Optional

import attr
from attr import validators

from sortedcontainers import SortedList, SortedKeyList

from cassis.typesystem import AnnotationBase

_validator_optional_string = validators.optional(validators.instance_of(str))


class IdGenerator:
    def __init__(self, initial_id: int = 1):
        self._next_id = initial_id

    def generate_id(self) -> int:
        result = self._next_id
        self._next_id += 1
        return result


@attr.s(slots=True)
class Sofa:
    """Each CAS has one or more Subject of Analysis (SofA)"""
Example 5
class Requirement(object):
    name = attrib()
    vcs = attrib(default=None, validator=validators.optional(_validate_vcs))
    req = attrib(default=None,
                 validator=_optional_instance_of(BaseRequirement))
    markers = attrib(default=None)
    specifiers = attrib(validator=validators.optional(_validate_specifiers))
    index = attrib(default=None)
    editable = attrib(default=None)
    hashes = attrib(default=Factory(list), converter=list)
    extras = attrib(default=Factory(list))
    _INCLUDE_FIELDS = ("name", "markers", "index", "editable", "hashes",
                       "extras")

    @name.default
    def get_name(self):
        return self.req.name

    @property
    def requirement(self):
        return self.req.req

    @property
    def hashes_as_pip(self):
        if self.hashes:
            return "".join([HASH_STRING.format(h) for h in self.hashes])

        return ""

    @property
    def markers_as_pip(self):
        if self.markers:
            return "; {0}".format(self.markers)

        return ""

    @property
    def extras_as_pip(self):
        if self.extras:
            return "[{0}]".format(",".join(self.extras))

        return ""

    @specifiers.default
    def get_specifiers(self):
        if self.req and self.req.req.specifier:
            return _specs_to_string(self.req.req.specs)
        return

    @property
    def is_vcs(self):
        return isinstance(self.req, VCSRequirement)

    @property
    def is_file_or_url(self):
        return isinstance(self.req, FileRequirement)

    @property
    def is_named(self):
        return isinstance(self.req, NamedRequirement)

    @property
    def normalized_name(self):
        if not self.is_vcs and not self.is_file_or_url:
            return pep423_name(self.name)
        return self.name

    @classmethod
    def from_line(cls, line):
        hashes = None
        if "--hash=" in line:
            hashes = line.split(" --hash=")
            line, hashes = hashes[0], hashes[1:]
        editable = line.startswith("-e ")
        stripped_line = line.split(" ", 1)[1] if editable else line
        line, markers = _split_markers(line)
        line, extras = _strip_extras(line)
        vcs = None
        # Installable local files and installable non-vcs urls are handled
        # as files, generally speaking
        if (is_installable_file(stripped_line) or
            (is_valid_url(stripped_line) and not is_vcs(stripped_line))):
            r = FileRequirement.from_line(line)
        elif is_vcs(stripped_line):
            r = VCSRequirement.from_line(line)
            vcs = r.vcs
        else:
            name = multi_split(stripped_line, "!=<>~")[0]
            if not extras:
                name, extras = _strip_extras(name)
            r = NamedRequirement.from_line(stripped_line)
        if extras:
            extras = first(
                requirements.parse("fakepkg{0}".format(
                    _extras_to_string(extras)))).extras
            r.req.extras = extras
        if markers:
            r.req.markers = markers
        args = {
            "name": r.name,
            "vcs": vcs,
            "req": r,
            "markers": markers,
            "editable": editable,
        }
        if extras:
            args["extras"] = extras
        if hashes:
            args["hashes"] = hashes
        return cls(**args)

    @classmethod
    def from_pipfile(cls, name, indexes, pipfile):
        _pipfile = {}
        if hasattr(pipfile, "keys"):
            _pipfile = dict(pipfile).copy()
        _pipfile["version"] = _get_version(pipfile)
        vcs = first([vcs for vcs in VCS_LIST if vcs in _pipfile])
        if vcs:
            _pipfile["vcs"] = vcs
            r = VCSRequirement.from_pipfile(name, pipfile)
        elif any(key in _pipfile for key in ["path", "file", "uri"]):
            r = FileRequirement.from_pipfile(name, pipfile)
        else:
            r = NamedRequirement.from_pipfile(name, pipfile)
        args = {
            "name": r.name,
            "vcs": vcs,
            "req": r,
            "markers": PipenvMarkers.from_pipfile(name, _pipfile).line_part,
            "extras": _pipfile.get("extras"),
            "editable": _pipfile.get("editable", False),
            "index": _pipfile.get("index"),
        }
        if any(key in _pipfile for key in ["hash", "hashes"]):
            args["hashes"] = _pipfile.get("hashes", [pipfile.get("hash")])
        return cls(**args)

    def as_line(self, include_index=False, project=None):
        line = "{0}{1}{2}{3}{4}".format(
            self.req.line_part,
            self.extras_as_pip,
            self.specifiers if self.specifiers else "",
            self.markers_as_pip,
            self.hashes_as_pip,
        )
        if include_index and not (self.requirement.local_file or self.vcs):
            from .utils import prepare_pip_source_args

            if self.index:
                pip_src_args = [project.get_source(self.index)]
            else:
                pip_src_args = project.sources
            index_string = " ".join(prepare_pip_source_args(pip_src_args))
            line = "{0} {1}".format(line, index_string)
        return line

    def as_pipfile(self, include_index=False):
        good_keys = ("hashes", "extras", "markers", "editable", "version",
                     "index") + VCS_LIST
        req_dict = {
            k: v
            for k, v in attr.asdict(self, recurse=False,
                                    filter=_filter_none).items()
            if k in good_keys
        }
        name = self.name
        base_dict = {
            k: v
            for k, v in self.req.pipfile_part[name].items()
            if k not in ["req", "link"]
        }
        base_dict.update(req_dict)
        conflicting_keys = ("file", "path", "uri")
        if "file" in base_dict and any(k in base_dict
                                       for k in conflicting_keys[1:]):
            conflicts = [k for k in conflicting_keys[1:] if k in base_dict]
            for k in conflicts:
                _ = base_dict.pop(k)
        if "hashes" in base_dict and len(base_dict["hashes"]) == 1:
            base_dict["hash"] = base_dict.pop("hashes")[0]
        if len(base_dict.keys()) == 1 and "version" in base_dict:
            base_dict = base_dict.get("version")
        return {name: base_dict}

    @property
    def pipfile_entry(self):
        return self.as_pipfile().copy().popitem()
Example 6
 def test_success(self, validator):
     """
     Nothing happens if validator succeeds.
     """
     v = optional(validator)
     v(None, simple_attr("test"), 42)
Example 7
def _optional_instance_of(cls):
    return validators.optional(validators.instance_of(cls))
Example 8
class FileRequirement(BaseRequirement):
    """File requirements for tar.gz installable files or wheels or setup.py
    containing directories."""
    path = attrib(default=None, validator=validators.optional(_validate_path))
    #: path to hit - without any of the VCS prefixes (like git+ / http+ / etc)
    uri = attrib()
    name = attrib()
    link = attrib()
    editable = attrib(default=None)
    req = attrib()
    _has_hashed_name = False
    _uri_scheme = None

    @uri.default
    def get_uri(self):
        if self.path and not self.uri:
            self._uri_scheme = "path"
            self.uri = path_to_url(os.path.abspath(self.path))

    @name.default
    def get_name(self):
        loc = self.path or self.uri
        if loc:
            self._uri_scheme = "path" if self.path else "uri"
        hashed_loc = hashlib.sha256(loc.encode("utf-8")).hexdigest()
        hash_fragment = hashed_loc[-7:]
        self._has_hashed_name = True
        return hash_fragment

    @link.default
    def get_link(self):
        target = "{0}#egg={1}".format(self.uri, self.name)
        return Link(target)

    @req.default
    def get_requirement(self):
        base = "{0}".format(self.link)
        req = first(requirements.parse(base))
        if self.editable:
            req.editable = True
        if self.link and self.link.scheme.startswith("file"):
            if self.path:
                req.path = self.path
                req.local_file = True
                self._uri_scheme = "file"
                req.uri = None
        req.link = self.link
        return req

    @property
    def is_remote_artifact(self):
        return any(
            self.link.scheme.startswith(scheme)
            for scheme in ("http", "https", "ftp", "ftps", "uri")) and (
                self.link.is_artifact
                or self.link.is_wheel) and not self.req.editable

    @classmethod
    def from_line(cls, line):
        link = None
        path = None
        editable = line.startswith("-e ")
        line = line.split(" ", 1)[1] if editable else line
        if not any([is_installable_file(line), is_valid_url(line)]):
            raise ValueError(
                "Supplied requirement is not installable: {0!r}".format(line))

        if is_valid_url(line):
            link = Link(line)
        else:
            _path = Path(line)
            link = Link(_path.absolute().as_uri())
            if _path.is_absolute() or _path.as_posix() == ".":
                path = _path.as_posix()
            else:
                path = get_converted_relative_path(line)
        arg_dict = {
            "path": path,
            "uri": link.url_without_fragment,
            "link": link,
            "editable": editable,
        }
        if link.egg_fragment:
            arg_dict["name"] = link.egg_fragment
        created = cls(**arg_dict)
        return created

    @classmethod
    def from_pipfile(cls, name, pipfile):
        uri_key = first((k for k in ["uri", "file"] if k in pipfile))
        uri = pipfile.get(uri_key, pipfile.get("path"))
        if not uri_key:
            abs_path = os.path.abspath(uri)
            uri = path_to_url(abs_path) if os.path.exists(abs_path) else None
        link = Link(uri) if uri else None
        arg_dict = {
            "name": name,
            "path": pipfile.get("path"),
            "uri": link.url_without_fragment,
            "editable": pipfile.get("editable"),
            "link": link,
        }
        return cls(**arg_dict)

    @property
    def line_part(self):
        seed = self.path or self.link.url or self.uri
        # add egg fragments to remote artifacts (valid urls only)
        if not self._has_hashed_name and self.is_remote_artifact:
            seed += "#egg={0}".format(self.name)
        editable = "-e " if self.editable else ""
        return "{0}{1}".format(editable, seed)

    @property
    def pipfile_part(self):
        pipfile_dict = {
            k: v
            for k, v in attr.asdict(self, filter=_filter_none).items()
        }
        name = pipfile_dict.pop("name")
        req = self.req
        # For local paths and remote installable artifacts (zipfiles, etc)
        if self.is_remote_artifact:
            dict_key = "file"
            # Look for "uri" first because "file" is also a URI format; this is
            # designed to make sure we add "file" keys to the Pipfile as a
            # replacement for "uri"
            target_keys = [
                k for k in pipfile_dict.keys() if k in ["uri", "path"]
            ]
            pipfile_dict[dict_key] = pipfile_dict.pop(first(target_keys))
            if len(target_keys) > 1:
                _ = pipfile_dict.pop(target_keys[1])
        else:
            collisions = [
                key for key in ["path", "uri", "file"] if key in pipfile_dict
            ]
            if len(collisions) > 1:
                for k in collisions[1:]:
                    _ = pipfile_dict.pop(k)
        return {name: pipfile_dict}
Example 9
class Boss(object):
    _W = attrib()
    _side = attrib(validator=instance_of(type(u"")))
    _url = attrib(validator=instance_of(type(u"")))
    _appid = attrib(validator=instance_of(type(u"")))
    _versions = attrib(validator=instance_of(dict))
    _client_version = attrib(validator=instance_of(tuple))
    _reactor = attrib()
    _eventual_queue = attrib()
    _cooperator = attrib()
    _journal = attrib(validator=provides(_interfaces.IJournal))
    _tor = attrib(validator=optional(provides(_interfaces.ITorManager)))
    _timing = attrib(validator=provides(_interfaces.ITiming))
    m = MethodicalMachine()
    set_trace = getattr(m, "_setTrace",
                        lambda self, f: None)  # pragma: no cover

    def __attrs_post_init__(self):
        self._build_workers()
        self._init_other_state()

    def _build_workers(self):
        self._N = Nameplate()
        self._M = Mailbox(self._side)
        self._S = Send(self._side, self._timing)
        self._O = Order(self._side, self._timing)
        self._K = Key(self._appid, self._versions, self._side, self._timing)
        self._R = Receive(self._side, self._timing)
        self._RC = RendezvousConnector(self._url, self._appid, self._side,
                                       self._reactor, self._journal, self._tor,
                                       self._timing, self._client_version)
        self._L = Lister(self._timing)
        self._A = Allocator(self._timing)
        self._I = Input(self._timing)
        self._C = Code(self._timing)
        self._T = Terminator()
        self._D = Dilator(self._reactor, self._eventual_queue,
                          self._cooperator)

        self._N.wire(self._M, self._I, self._RC, self._T)
        self._M.wire(self._N, self._RC, self._O, self._T)
        self._S.wire(self._M)
        self._O.wire(self._K, self._R)
        self._K.wire(self, self._M, self._R)
        self._R.wire(self, self._S)
        self._RC.wire(self, self._N, self._M, self._A, self._L, self._T)
        self._L.wire(self._RC, self._I)
        self._A.wire(self._RC, self._C)
        self._I.wire(self._C, self._L)
        self._C.wire(self, self._A, self._N, self._K, self._I)
        self._T.wire(self, self._RC, self._N, self._M, self._D)
        self._D.wire(self._S, self._T)

    def _init_other_state(self):
        self._did_start_code = False
        self._next_tx_phase = 0
        self._next_rx_phase = 0
        self._rx_phases = {}  # phase -> plaintext

        self._next_rx_dilate_seqnum = 0
        self._rx_dilate_seqnums = {}  # seqnum -> plaintext

        self._result = "empty"

    # these methods are called from outside
    def start(self):
        self._RC.start()

    def _print_trace(self, old_state, input, new_state, client_name, machine,
                     file):
        if new_state:
            print(
                "%s.%s[%s].%s -> [%s]" % (client_name, machine, old_state,
                                          input, new_state),
                file=file)
        else:
            # the RendezvousConnector emits message events as if
            # they were state transitions, except that old_state
            # and new_state are empty strings. "input" is one of
            # R.connected, R.rx(type phase+side), R.tx(type
            # phase), R.lost .
            print("%s.%s.%s" % (client_name, machine, input), file=file)
        file.flush()

        def output_tracer(output):
            print(" %s.%s.%s()" % (client_name, machine, output), file=file)
            file.flush()

        return output_tracer

    def _set_trace(self, client_name, which, file):
        names = {
            "B": self,
            "N": self._N,
            "M": self._M,
            "S": self._S,
            "O": self._O,
            "K": self._K,
            "SK": self._K._SK,
            "R": self._R,
            "RC": self._RC,
            "L": self._L,
            "A": self._A,
            "I": self._I,
            "C": self._C,
            "T": self._T
        }
        for machine in which.split():
            t = (lambda old_state, input, new_state, machine=machine:
                 self._print_trace(old_state, input, new_state,
                                   client_name=client_name,
                                   machine=machine, file=file))
            names[machine].set_trace(t)
            if machine == "I":
                self._I.set_debug(t)

    # def serialize(self):
    #     raise NotImplemented

    # and these are the state-machine transition functions, which don't take
    # args
    @m.state(initial=True)
    def S0_empty(self):
        pass  # pragma: no cover

    @m.state()
    def S1_lonely(self):
        pass  # pragma: no cover

    @m.state()
    def S2_happy(self):
        pass  # pragma: no cover

    @m.state()
    def S3_closing(self):
        pass  # pragma: no cover

    @m.state(terminal=True)
    def S4_closed(self):
        pass  # pragma: no cover

    # from the Wormhole

    # input/allocate/set_code are regular methods, not state-transition
    # inputs. We expect them to be called just after initialization, while
    # we're in the S0_empty state. You must call exactly one of them, and the
    # call must happen while we're in S0_empty, which makes them good
    # candidates for being a proper @m.input, but set_code() will immediately
    # (reentrantly) cause self.got_code() to be fired, which is messy. These
    # are all passthroughs to the Code machine, so one alternative would be
    # to have Wormhole call Code.{input,allocate,set_code} instead, but that
    # would require the Wormhole to be aware of Code (whereas right now
    # Wormhole only knows about this Boss instance, and everything else is
    # hidden away).
    def input_code(self):
        if self._did_start_code:
            raise OnlyOneCodeError()
        self._did_start_code = True
        return self._C.input_code()

    def allocate_code(self, code_length):
        if self._did_start_code:
            raise OnlyOneCodeError()
        self._did_start_code = True
        wl = PGPWordList()
        self._C.allocate_code(code_length, wl)

    def set_code(self, code):
        validate_code(code)  # can raise KeyFormatError
        if self._did_start_code:
            raise OnlyOneCodeError()
        self._did_start_code = True
        self._C.set_code(code)

    def dilate(self, transit_relay_location=None, no_listen=False):
        return self._D.dilate(transit_relay_location, no_listen=no_listen)  # fires with endpoints

    @m.input()
    def send(self, plaintext):
        pass

    @m.input()
    def close(self):
        pass

    # from RendezvousConnector:
    # * "rx_welcome" is the Welcome message, which might signal an error, or
    #   our welcome_handler might signal one
    # * "rx_error" is error message from the server (probably because of
    #   something we said badly, or due to CrowdedError)
    # * "error" is when an exception happened while it tried to deliver
    #   something else
    def rx_welcome(self, welcome):
        try:
            if "error" in welcome:
                raise WelcomeError(welcome["error"])
            # TODO: it'd be nice to not call the handler when we're in
            # S3_closing or S4_closed states. I tried to implement this with
            # rx_welcome as an @input, but in the error case I'd be
            # delivering a new input (rx_error or something) while in the
            # middle of processing the rx_welcome input, and I wasn't sure
            # Automat would handle that correctly.
            self._W.got_welcome(welcome)  # TODO: let this raise WelcomeError?
        except WelcomeError as welcome_error:
            self.rx_unwelcome(welcome_error)

    @m.input()
    def rx_unwelcome(self, welcome_error):
        pass

    @m.input()
    def rx_error(self, errmsg, orig):
        pass

    @m.input()
    def error(self, err):
        pass

    # from Code (provoked by input/allocate/set_code)
    @m.input()
    def got_code(self, code):
        pass

    # Key sends (got_key, scared)
    # Receive sends (got_message, happy, got_verifier, scared)
    @m.input()
    def happy(self):
        pass

    @m.input()
    def scared(self):
        pass

    def got_message(self, phase, plaintext):
        assert isinstance(phase, type("")), type(phase)
        assert isinstance(plaintext, type(b"")), type(plaintext)
        d_mo = re.search(r'^dilate-(\d+)$', phase)
        if phase == "version":
            self._got_version(plaintext)
        elif d_mo:
            self._got_dilate(int(d_mo.group(1)), plaintext)
        elif re.search(r'^\d+$', phase):
            self._got_phase(int(phase), plaintext)
        else:
            # Ignore unrecognized phases, for forwards-compatibility. Use
            # log.err so tests will catch surprises.
            log.err(_UnknownPhaseError("received unknown phase '%s'" % phase))

    @m.input()
    def _got_version(self, plaintext):
        pass

    @m.input()
    def _got_phase(self, phase, plaintext):
        pass

    @m.input()
    def _got_dilate(self, seqnum, plaintext):
        pass

    @m.input()
    def got_key(self, key):
        pass

    @m.input()
    def got_verifier(self, verifier):
        pass

    # Terminator sends closed
    @m.input()
    def closed(self):
        pass

    @m.output()
    def do_got_code(self, code):
        self._W.got_code(code)

    @m.output()
    def process_version(self, plaintext):
        # most of this is wormhole-to-wormhole, ignored for now
        # in the future, this is how Dilation is signalled
        self._their_versions = bytes_to_dict(plaintext)
        self._D.got_wormhole_versions(self._their_versions)
        # but this part is app-to-app
        app_versions = self._their_versions.get("app_versions", {})
        self._W.got_versions(app_versions)

    @m.output()
    def S_send(self, plaintext):
        assert isinstance(plaintext, type(b"")), type(plaintext)
        phase = self._next_tx_phase
        self._next_tx_phase += 1
        self._S.send("%d" % phase, plaintext)

    @m.output()
    def close_unwelcome(self, welcome_error):
        # assert isinstance(err, WelcomeError)
        self._result = welcome_error
        self._T.close("unwelcome")

    @m.output()
    def close_error(self, errmsg, orig):
        self._result = ServerError(errmsg)
        self._T.close("errory")

    @m.output()
    def close_scared(self):
        self._result = WrongPasswordError()
        self._T.close("scary")

    @m.output()
    def close_lonely(self):
        self._result = LonelyError()
        self._T.close("lonely")

    @m.output()
    def close_happy(self):
        self._result = "happy"
        self._T.close("happy")

    @m.output()
    def W_got_key(self, key):
        self._W.got_key(key)

    @m.output()
    def D_got_key(self, key):
        self._D.got_key(key)

    @m.output()
    def W_got_verifier(self, verifier):
        self._W.got_verifier(verifier)

    @m.output()
    def W_received(self, phase, plaintext):
        assert isinstance(phase, six.integer_types), type(phase)
        # we call Wormhole.received() in strict phase order, with no gaps
        self._rx_phases[phase] = plaintext
        while self._next_rx_phase in self._rx_phases:
            self._W.received(self._rx_phases.pop(self._next_rx_phase))
            self._next_rx_phase += 1

    @m.output()
    def D_received_dilate(self, seqnum, plaintext):
        assert isinstance(seqnum, six.integer_types), type(seqnum)
        # strict phase order, no gaps
        self._rx_dilate_seqnums[seqnum] = plaintext
        while self._next_rx_dilate_seqnum in self._rx_dilate_seqnums:
            m = self._rx_dilate_seqnums.pop(self._next_rx_dilate_seqnum)
            self._D.received_dilate(m)
            self._next_rx_dilate_seqnum += 1

    @m.output()
    def W_close_with_error(self, err):
        self._result = err  # exception
        self._W.closed(self._result)

    @m.output()
    def W_closed(self):
        # result is either "happy" or a WormholeError of some sort
        self._W.closed(self._result)

    S0_empty.upon(close, enter=S3_closing, outputs=[close_lonely])
    S0_empty.upon(send, enter=S0_empty, outputs=[S_send])
    S0_empty.upon(rx_unwelcome, enter=S3_closing, outputs=[close_unwelcome])
    S0_empty.upon(got_code, enter=S1_lonely, outputs=[do_got_code])
    S0_empty.upon(rx_error, enter=S3_closing, outputs=[close_error])
    S0_empty.upon(error, enter=S4_closed, outputs=[W_close_with_error])

    S1_lonely.upon(rx_unwelcome, enter=S3_closing, outputs=[close_unwelcome])
    S1_lonely.upon(happy, enter=S2_happy, outputs=[])
    S1_lonely.upon(scared, enter=S3_closing, outputs=[close_scared])
    S1_lonely.upon(close, enter=S3_closing, outputs=[close_lonely])
    S1_lonely.upon(send, enter=S1_lonely, outputs=[S_send])
    S1_lonely.upon(got_key, enter=S1_lonely, outputs=[W_got_key, D_got_key])
    S1_lonely.upon(rx_error, enter=S3_closing, outputs=[close_error])
    S1_lonely.upon(error, enter=S4_closed, outputs=[W_close_with_error])

    S2_happy.upon(rx_unwelcome, enter=S3_closing, outputs=[close_unwelcome])
    S2_happy.upon(got_verifier, enter=S2_happy, outputs=[W_got_verifier])
    S2_happy.upon(_got_phase, enter=S2_happy, outputs=[W_received])
    S2_happy.upon(_got_version, enter=S2_happy, outputs=[process_version])
    S2_happy.upon(_got_dilate, enter=S2_happy, outputs=[D_received_dilate])
    S2_happy.upon(scared, enter=S3_closing, outputs=[close_scared])
    S2_happy.upon(close, enter=S3_closing, outputs=[close_happy])
    S2_happy.upon(send, enter=S2_happy, outputs=[S_send])
    S2_happy.upon(rx_error, enter=S3_closing, outputs=[close_error])
    S2_happy.upon(error, enter=S4_closed, outputs=[W_close_with_error])

    S3_closing.upon(rx_unwelcome, enter=S3_closing, outputs=[])
    S3_closing.upon(rx_error, enter=S3_closing, outputs=[])
    S3_closing.upon(got_verifier, enter=S3_closing, outputs=[])
    S3_closing.upon(_got_phase, enter=S3_closing, outputs=[])
    S3_closing.upon(_got_version, enter=S3_closing, outputs=[])
    S3_closing.upon(_got_dilate, enter=S3_closing, outputs=[])
    S3_closing.upon(happy, enter=S3_closing, outputs=[])
    S3_closing.upon(scared, enter=S3_closing, outputs=[])
    S3_closing.upon(close, enter=S3_closing, outputs=[])
    S3_closing.upon(send, enter=S3_closing, outputs=[])
    S3_closing.upon(closed, enter=S4_closed, outputs=[W_closed])
    S3_closing.upon(error, enter=S4_closed, outputs=[W_close_with_error])

    S4_closed.upon(rx_unwelcome, enter=S4_closed, outputs=[])
    S4_closed.upon(got_verifier, enter=S4_closed, outputs=[])
    S4_closed.upon(_got_phase, enter=S4_closed, outputs=[])
    S4_closed.upon(_got_version, enter=S4_closed, outputs=[])
    S4_closed.upon(_got_dilate, enter=S4_closed, outputs=[])
    S4_closed.upon(happy, enter=S4_closed, outputs=[])
    S4_closed.upon(scared, enter=S4_closed, outputs=[])
    S4_closed.upon(close, enter=S4_closed, outputs=[])
    S4_closed.upon(send, enter=S4_closed, outputs=[])
    S4_closed.upon(error, enter=S4_closed, outputs=[])
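A hedged, self-contained sketch of the `optional(provides(...))` pattern used for `_tor` above. It assumes zope.interface is installed; note that `attr.validators.provides` is deprecated (and removed in recent attrs releases), so newer code typically replaces it with a custom validator.

import attr
from attr.validators import optional, provides
from zope.interface import Interface, implementer

class ITorManager(Interface):
    """Marker interface standing in for wormhole's _interfaces.ITorManager."""

@implementer(ITorManager)
class FakeTor(object):
    pass

@attr.s
class Holder(object):
    tor = attr.ib(validator=optional(provides(ITorManager)))

Holder(tor=None)       # allowed: optional() lets None through
Holder(tor=FakeTor())  # allowed: the instance provides ITorManager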
Example 10
class FovSubsetData:
    """
    A dataclass-like object that keeps track of the data for a subset of
    a given FOV. Used, for example, when a FOV has both labeled and
    unlabeled cells; in that case, a FovData instance will contain
    two FovSubsetData instances.

    Parameters:
    :param pathlib.Path results_file: an .npz file generated by CaImAn
    :param Optional[bool] with_labeling: Controls whether the data
    was taken with a second channel containing morphological data. True
    means that this subset points to the labeled data, False means that
    this subset points to the unlabeled data, and None means that there was
    no colabeling involved with this data.
    :param pathlib.Path colabel_img: Path to the tif file containing the
    images of the cell as seen in the colabeled channel.
    """

    results_file = attr.ib(validator=instance_of(pathlib.Path))
    with_labeling = attr.ib(validator=optional(instance_of(bool)))
    colabel_img = attr.ib(default=None)
    tif_file = attr.ib(init=False)
    colabel_file = attr.ib(init=False)
    colabel_stack = attr.ib(init=False)
    dff = attr.ib(init=False, repr=False)
    indices = attr.ib(init=False, repr=False)
    loaded = attr.ib(init=False)

    def load_data(self):
        """ Main class method to populate its different
        attributes with the data and proper files """
        self.tif_file = self._find_tif_file()
        if self.with_labeling is not None:
            self.colabel_file = self._find_colabeled_file()
            self.colabel_stack = self._load_colabeled_img()
        self.dff, self.indices = self._populate_dff_data()
        self.loaded = True

    def _find_tif_file(self):
        """
        Finds and returns the associated tif file. Returns None if
        it doesn't exist.
        """
        name = self.results_file.name[:-12] + ".tif"
        try:
            tif_file = next(self.results_file.parent.glob(name))
            return tif_file
        except StopIteration:
            print(f"Tif not found for {name}")
            return None

    def _find_colabeled_file(self) -> Union[pathlib.Path, None]:
        """
        Finds and returns the colabeled file. Returns None if
        it doesn't exist.
        """
        name = self.results_file.name[:-11] + "colabeled_idx.npy"
        try:
            colabel_file = next(self.results_file.parent.glob(name))
            return colabel_file
        except StopIteration:
            return None

    def _load_colabeled_img(self) -> np.ndarray:
        """
        Loads a tif file containing the parallel data
        of the colabeled cells to memory.
        """
        return tifffile.imread(str(self.colabel_img))

    def _populate_dff_data(self):
        """
        Using the different found filenames, load the dF/F data into
        memory. If a subset of the rows should be loaded (since we're
        working with labeled data) the indices of the relevant
        rows are also returned.
        """
        all_data = np.load(self.results_file)["F_dff"]
        if self.with_labeling is None:
            return all_data, np.arange(all_data.shape[0])

        indices = np.load(self.colabel_file)
        if self.with_labeling:
            return all_data[indices], indices
        if not self.with_labeling:
            all_indices = np.arange(all_data.shape[0])
            remaining_indices = np.delete(all_indices, indices)
            remaining_traces = all_data[remaining_indices]
            return remaining_traces, remaining_indices
Example 11
class ConfigSchema(Schema):

    NAME: ClassVar[str] = "config"

    # model config

    entities: Dict[str, EntitySchema] = attr.ib(
        validator=non_empty,
        metadata={
            "help":
            "The entity types. The ID with which they are "
            "referenced by the relation types is the key they "
            "have in this dict."
        },
    )
    relations: List[RelationSchema] = attr.ib(
        validator=non_empty,
        metadata={
            "help":
            "The relation types. The ID with which they will be "
            "referenced in the edge lists is their index in this "
            "list."
        },
    )
    dimension: int = attr.ib(
        validator=positive,
        metadata={
            "help": "The dimension of the real space the embedding live "
            "in."
        },
    )
    init_scale: float = attr.ib(
        default=1e-3,
        validator=positive,
        metadata={
            "help":
            "If no initial embeddings are provided, they are "
            "generated by sampling each dimension from a "
            "centered normal distribution having this standard "
            "deviation. (For performance reasons, sampling isn't "
            "fully independent.)"
        },
    )
    max_norm: Optional[float] = attr.ib(
        default=None,
        validator=optional(positive),
        metadata={
            "help":
            "If set, rescale the embeddings if their norm "
            "exceeds this value."
        },
    )
    global_emb: bool = attr.ib(
        default=True,
        metadata={
            "help":
            "If enabled, add to each embedding a vector that is "
            "common to all the entities of a certain type. This "
            "vector is learned during training."
        },
    )
    comparator: str = attr.ib(
        default="cos",
        metadata={
            "help":
            "How the embeddings of the two sides of an edge "
            "(after having already undergone some processing) "
            "are compared to each other to produce a score."
        },
    )
    bias: bool = attr.ib(
        default=False,
        metadata={
            "help":
            "If enabled, withhold the first dimension of the "
            "embeddings from the comparator and instead use it "
            "as a bias, adding back to the score. Makes sense "
            "for logistic and softmax loss functions."
        },
    )
    loss_fn: str = attr.ib(
        default="ranking",
        metadata={
            "help":
            "How the scores of positive edges and their "
            "corresponding negatives are evaluated."
        },
    )
    margin: float = attr.ib(
        default=0.1,
        metadata={
            "help":
            "When using ranking loss, this value controls the "
            "minimum separation between positive and negative "
            "scores, below which a (linear) loss is incured."
        },
    )

    # data config

    entity_path: str = attr.ib(
        metadata={
            "help": "The path of the directory containing entity count "
            "files."
        })
    edge_paths: List[str] = attr.ib(
        metadata={
            "help":
            "A list of paths to directories containing "
            "(partitioned) edgelists. Typically a single path is "
            "provided."
        })
    checkpoint_path: str = attr.ib(
        metadata={
            "help":
            "The path to the directory where checkpoints (and "
            "thus the output) will be written to. If checkpoints "
            "are found in it, training will resume from them."
        })
    init_path: Optional[str] = attr.ib(
        default=None,
        metadata={
            "help":
            "If set, it must be a path to a directory that "
            "contains initial values for the embeddings of all "
            "the entities of some types."
        },
    )
    checkpoint_preservation_interval: Optional[int] = attr.ib(
        default=None,
        metadata={
            "help":
            "If set, every so many epochs a snapshot of the "
            "checkpoint will be archived. The snapshot will be "
            "located inside a `epoch_{N}` sub-directory of the "
            "checkpoint directory, and will contain symbolic "
            "links to the original checkpoint files, which will "
            "not be cleaned-up as it would normally happen."
        },
    )

    # training config

    num_epochs: int = attr.ib(
        default=1,
        validator=non_negative,
        metadata={
            "help":
            "The number of times the training loop iterates over "
            "all the edges."
        },
    )
    num_edge_chunks: Optional[int] = attr.ib(
        default=None,
        validator=optional(positive),
        metadata={
            "help":
            "The number of equally-sized parts each bucket will "
            "be split into. Training will first proceed over all "
            "the first chunks of all buckets, then over all the "
            "second chunks, and so on. A higher value allows "
            "better mixing of partitions, at the cost of more "
            "time spent on I/O. If unset, will be automatically "
            "calculated so that no chunk has more than "
            "max_edges_per_chunk edges."
        },
    )
    max_edges_per_chunk: int = attr.ib(
        default=1_000_000_000,  # Each edge having 3 int64s, this is 12GB.
        validator=positive,
        metadata={
            "help":
            "The maximum number of edges that each edge chunk "
            "should contain if the number of edge chunks is left "
            "unspecified and has to be automatically figured "
            "out. Each edge takes up at least 12 bytes (3 "
            "int64s), more if using featurized entities."
        },
    )
    bucket_order: BucketOrder = attr.ib(
        default=BucketOrder.INSIDE_OUT,
        metadata={"help": "The order in which to iterate over the buckets."},
    )
    workers: Optional[int] = attr.ib(
        default=None,
        validator=optional(positive),
        metadata={
            "help":
            'The number of worker processes for "Hogwild!" '
            "training. If not given, set to CPU count."
        },
    )
    batch_size: int = attr.ib(
        default=1000,
        validator=positive,
        metadata={"help": "The number of edges per batch."},
    )
    num_batch_negs: int = attr.ib(
        default=50,
        validator=non_negative,
        metadata={
            "help":
            "The number of negatives sampled from the batch, per "
            "positive edge."
        },
    )
    num_uniform_negs: int = attr.ib(
        default=50,
        validator=non_negative,
        metadata={
            "help":
            "The number of negatives uniformly sampled from the "
            "currently active partition, per positive edge."
        },
    )
    disable_lhs_negs: bool = attr.ib(
        default=False,
        metadata={"help": "Disable negative sampling on the left-hand side."},
    )
    disable_rhs_negs: bool = attr.ib(
        default=False,
        metadata={"help": "Disable negative sampling on the right-hand side."},
    )
    lr: float = attr.ib(
        default=1e-2,
        validator=non_negative,
        metadata={"help": "The learning rate for the optimizer."},
    )
    relation_lr: Optional[float] = attr.ib(
        default=None,
        validator=optional(non_negative),
        metadata={
            "help":
            "If set, the learning rate for the optimizer"
            "for relations. Otherwise, `lr' is used."
        },
    )
    eval_fraction: float = attr.ib(
        default=0.05,
        validator=non_negative,
        metadata={
            "help":
            "The fraction of edges withheld from training and "
            "used to track evaluation metrics during training."
        },
    )
    eval_num_batch_negs: Optional[int] = attr.ib(
        default=1000,
        validator=optional(non_negative),
        metadata={
            "help":
            "If set, overrides the number of negatives "
            "per positive edge sampled from the batch during the "
            "evaluation steps that occur before and after each "
            "training step."
        },
    )
    eval_num_uniform_negs: Optional[int] = attr.ib(
        default=1000,
        validator=optional(non_negative),
        metadata={
            "help":
            "If set, overrides the number of "
            "uniformly-sampled negatives per positive edge "
            "during the evaluation steps that occur before and "
            "after each training step."
        },
    )

    # expert options

    background_io: bool = attr.ib(
        default=False,
        metadata={
            "help":
            "Whether to do load/save in a background process. "
            "DEPRECATED."
        },
    )
    verbose: int = attr.ib(
        default=0,
        validator=non_negative,
        metadata={"help": "The verbosity level of logging, currently 0 or 1."},
    )
    hogwild_delay: float = attr.ib(
        default=2,
        validator=non_negative,
        metadata={
            "help":
            "The number of seconds by which to delay the start "
            'of all "Hogwild!" processes except the first one, '
            "on the first epoch."
        },
    )
    dynamic_relations: bool = attr.ib(
        default=False,
        metadata={
            "help":
            "If enabled, activates the dynamic relation mode, in "
            "which case, there must be a single relation type in "
            "the config (whose parameters will apply to all "
            "dynamic relations types) and there must be a file "
            "called dynamic_rel_count.txt in the entity path that "
            "contains the number of dynamic relations. In this "
            "mode, batches will contain edges of multiple "
            "relation types and negatives will be sampled "
            "differently."
        },
    )

    # distributed training config options

    num_machines: int = attr.ib(
        default=1,
        validator=positive,
        metadata={"help": "The number of machines for distributed training."},
    )
    num_partition_servers: int = attr.ib(
        default=-1,
        metadata={
            "help":
            "If -1, use trainer as partition servers. If 0, "
            "don't use partition servers (instead, swap "
            "partitions through disk). If >1, then that number "
            "of partition servers must be started manually."
        },
    )
    distributed_init_method: Optional[str] = attr.ib(
        default=None,
        metadata={
            "help":
            "A URI defining how to synchronize all the workers "
            "of a distributed run. Must start with a scheme "
            "(e.g., file:// or tcp://) supported by PyTorch."
        },
    )
    distributed_tree_init_order: bool = attr.ib(
        default=True,
        metadata={
            "help":
            "If enabled, then distributed training can occur on "
            "a bucket only if at least one of its partitions was "
            "already trained on before in the same round (or if "
            "one of its partitions is 0, for bootstrapping)."
        },
    )

    num_gpus: int = attr.ib(
        default=0,
        metadata={
            "help":
            "Number of GPUs to use for GPU training. "
            "Experimental: Not yet supported."
        },
    )
    num_groups_for_partition_server: int = attr.ib(
        default=16,
        metadata={
            "help":
            "Number of td.distributed 'groups' to use. Setting "
            "this to a value around 16 typically increases "
            "communication bandwidth."
        },
    )
    half_precision: bool = attr.ib(
        default=False,
        metadata={"help": "Use half-precision training (GPU ONLY)"})

    # Additional global validation.

    def __attrs_post_init__(self) -> None:
        for rel_id, rel_config in enumerate(self.relations):
            if rel_config.lhs not in self.entities:
                raise ValueError("Relation type %s (#%d) has an unknown "
                                 "left-hand side entity type %s" %
                                 (rel_config.name, rel_id, rel_config.lhs))
            if rel_config.rhs not in self.entities:
                raise ValueError("Relation type %s (#%d) has an unknown "
                                 "right-hand side entity type %s" %
                                 (rel_config.name, rel_id, rel_config.rhs))
        if self.dynamic_relations:
            if len(self.relations) != 1:
                raise ValueError("When dynamic relations are in use only one "
                                 "relation type must be defined.")
        # TODO Check that all partitioned entity types have the same number of
        # partitions
        # TODO Check that the batch size is a multiple of the batch negative number
        if self.loss_fn == "logistic" and self.comparator == "cos":
            logger.warning(
                "You have logistic loss and cosine distance. Are you sure?")

        if self.disable_lhs_negs and self.disable_rhs_negs:
            raise ValueError("Cannot disable negative sampling on both sides.")

        if self.background_io:
            logger.warning(
                "`background_io` is deprecated and will have no effect.")

    def entity_dimension(self, entity_type: str) -> int:
        """get the dimension for an entity"""
        return self.entities[entity_type].dimension or self.dimension
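A minimal sketch of the `optional(positive)` pattern used repeatedly above, assuming `positive` is a plain attrs-style validator (the real torchbiggraph helpers may be implemented differently):

import attr
from attr.validators import optional

def positive(instance, attribute, value):
    # attrs validators take (instance, attribute, value) and raise on failure
    if value <= 0:
        raise ValueError(f"{attribute.name} must be positive, got {value!r}")

@attr.s
class Opts:  # illustrative, not the real ConfigSchema
    max_norm = attr.ib(default=None, validator=optional(positive))

Opts()              # None: positive() is never called
Opts(max_norm=0.5)  # 0.5 > 0: passes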
Example 12
class MultiKeyring(Keyring):
    """Public class for Multi Keyring.

    .. versionadded:: 1.5.0

    :param Keyring generator: Generator keyring used to generate data encryption key (optional)
    :param List[Keyring] children: List of keyrings used to encrypt the data encryption key (optional)
    :raises EncryptKeyError: if encryption of data key fails for any reason
    """

    generator = attr.ib(default=None, validator=optional(instance_of(Keyring)))
    children = attr.ib(
        default=attr.Factory(tuple),
        validator=optional(
            deep_iterable(member_validator=instance_of(Keyring))))

    def __attrs_post_init__(self):
        # type: () -> None
        """Prepares initial values not handled by attrs."""
        neither_generator_nor_children = self.generator is None and not self.children
        if neither_generator_nor_children:
            raise TypeError(
                "At least one of generator or children must be provided")

        _generator = (self.generator, ) if self.generator is not None else ()
        self._decryption_keyrings = list(
            itertools.chain(_generator, self.children))

    def on_encrypt(self, encryption_materials):
        # type: (EncryptionMaterials) -> EncryptionMaterials
        """Generate a data key using generator keyring
        and encrypt it using any available wrapping key in any child keyring.

        :param EncryptionMaterials encryption_materials: Encryption materials for keyring to modify.
        :returns: Optionally modified encryption materials.
        :rtype: EncryptionMaterials
        :raises EncryptKeyError: if unable to encrypt data key.
        """
        # Check if generator keyring is not provided and data key is not generated
        if self.generator is None and encryption_materials.data_encryption_key is None:
            raise EncryptKeyError(
                "Generator keyring not provided "
                "and encryption materials do not already contain a plaintext data key."
            )

        new_materials = encryption_materials

        # Call on_encrypt on the generator keyring if it is provided
        if self.generator is not None:
            new_materials = self.generator.on_encrypt(
                encryption_materials=new_materials)

        # Check if data key is generated
        if new_materials.data_encryption_key is None:
            raise GenerateKeyError("Unable to generate data encryption key.")

        # Call on_encrypt on all other keyrings
        for keyring in self.children:
            new_materials = keyring.on_encrypt(
                encryption_materials=new_materials)

        return new_materials

    def on_decrypt(self, decryption_materials, encrypted_data_keys):
        # type: (DecryptionMaterials, Iterable[EncryptedDataKey]) -> DecryptionMaterials
        """Attempt to decrypt the encrypted data keys.

        :param DecryptionMaterials decryption_materials: Decryption materials for keyring to modify.
        :param List[EncryptedDataKey] encrypted_data_keys: List of encrypted data keys.
        :returns: Optionally modified decryption materials.
        :rtype: DecryptionMaterials
        """
        # Call on_decrypt on all keyrings till decryption is successful
        new_materials = decryption_materials
        for keyring in self._decryption_keyrings:
            if new_materials.data_encryption_key is not None:
                return new_materials

            new_materials = keyring.on_decrypt(
                decryption_materials=new_materials,
                encrypted_data_keys=encrypted_data_keys)

        return new_materials
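A hedged, standalone sketch (not the aws-encryption-sdk API) of the validator combination used for `children` above: `optional()` wrapping `deep_iterable()` accepts None or any iterable whose members all pass the inner validator.

import attr
from attr.validators import deep_iterable, instance_of, optional

class Child(object):  # stand-in for Keyring
    pass

@attr.s
class Parent(object):
    children = attr.ib(
        default=attr.Factory(tuple),
        validator=optional(deep_iterable(member_validator=instance_of(Child))),
    )

Parent()                      # empty tuple: vacuously valid
Parent(children=(Child(),))   # each member is checked against instance_of(Child)
Parent(children=None)         # allowed only because of the optional() wrapper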
Example 13
 def test_success_with_none(self, validator):
     """
     Nothing happens if None.
     """
     v = optional(validator)
     v(None, simple_attr("test"), None)
Example 14
 def test_success_with_none(self):
     """
     Nothing happens if None.
     """
     v = optional(instance_of(int))
     v(None, simple_attr("test"), None)
Example 15
 def test_success(self, validator):
     """
     Nothing happens if validator succeeds.
     """
     v = optional(validator)
     v(None, simple_attr("test"), 42)
Example 16
def optional_instance_of(cls):
    return validators.optional(validators.instance_of(cls))
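A minimal usage sketch for the helper above; the class and field are illustrative only, and the helper is repeated so the sketch runs standalone.

import attr
from attr import validators

def optional_instance_of(cls):  # repeated from above for self-containment
    return validators.optional(validators.instance_of(cls))

@attr.s
class Event(object):  # hypothetical example class
    timestamp = attr.ib(default=None, validator=optional_instance_of(float))

Event()               # None accepted
Event(timestamp=1.5)  # float accepted; anything else raises TypeError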
Example 17
class PipelineAction(_ConfigStructure):
    """CodePipeline action definition.

    :param provider: Action provider name
        (must be a valid CodePipeline action provider name)
    :param inputs: Names of CodePipeline inputs to collect
    :param outputs: Names of CodePipeline outputs to emit
    :param configuration: Additional string-string map of configuration values to provide in
        CodePipeline action definition
    :param image: Docker image to use with CodeBuild
        (only used for CodeBuild provider actions)
    :param environment_type: CodeBuild environment type name
        (only used for CodeBuild provider actions)
        (if not provided, we will attempt to guess based on the image name)
    :param buildspec: Location of CodeBuild buildspec in source
        (only used for CodeBuild provider actions)
        (in-line buildspec definition not supported)
    :param compute_type: CodeBuild compute type name
        (only used for CodeBuild provider actions)
        (default: ``BUILD_GENERAL1_SMALL``)
    :param env: Mapping of environment variables to set in action environment
    :param run_order: CodePipeline action run order
    """

    provider: str = attr.ib(validator=instance_of(str))
    inputs: Set[str] = attr.ib(
        default=attr.Factory(set), validator=optional(deep_iterable(member_validator=instance_of(str)))
    )
    outputs: Set[str] = attr.ib(
        default=attr.Factory(set), validator=optional(deep_iterable(member_validator=instance_of(str)))
    )
    configuration: Dict[str, str] = attr.ib(default=attr.Factory(dict), validator=optional(_STRING_STRING_MAP))
    image: Optional[str] = attr.ib(default=None, validator=optional(instance_of(str)))
    environment_type: Optional[str] = attr.ib(default=None, validator=optional(instance_of(str)))
    buildspec: Optional[str] = attr.ib(default=None, validator=optional(instance_of(str)))
    compute_type: str = attr.ib(default="BUILD_GENERAL1_SMALL", validator=optional(instance_of(str)))
    env: Dict[str, str] = attr.ib(default=attr.Factory(dict), validator=optional(_STRING_STRING_MAP))
    run_order: int = attr.ib(default=1, validator=optional(instance_of(int)))

    @run_order.validator
    def _check_run_order(self, attribute, value):  # pylint: disable=unused-argument,no-self-use
        """Verify that ``run_order`` value is valid."""
        if value < 1:
            raise ValueError("PipelineAction run_order value must be >= 1")

    @image.validator
    def _check_image(self, attribute, value):  # pylint: disable=unused-argument
        """Verify that ``image`` is set if provider type ``CodeBuild`` is used."""
        if self.provider == "CodeBuild" and value is None:
            raise ValueError('image must be defined for actions of type "CodeBuild"')

    @buildspec.validator
    def _check_buildspec(self, attribute, value):  # pylint: disable=unused-argument
        """Verify that ``buildspec`` is set if provider type ``CodeBuild`` is used."""
        if self.provider == "CodeBuild" and value is None:
            raise ValueError('buildspec must be defined for actions of type "CodeBuild"')

    def __attrs_post_init__(self):
        """Set default values for ``environment_type``."""
        if self.provider == "CodeBuild" and self.environment_type is None:
            if "windows" in self.image.lower():
                self.environment_type = "WINDOWS_CONTAINER"
            else:
                self.environment_type = "LINUX_CONTAINER"
Example 18
class RendezvousConnector(object):
    _url = attrib(validator=instance_of(type(u"")))
    _appid = attrib(validator=instance_of(type(u"")))
    _side = attrib(validator=instance_of(type(u"")))
    _reactor = attrib()
    _journal = attrib(validator=provides(_interfaces.IJournal))
    _tor = attrib(validator=optional(provides(_interfaces.ITorManager)))
    _timing = attrib(validator=provides(_interfaces.ITiming))

    def __attrs_post_init__(self):
        self._trace = None
        self._ws = None
        f = WSFactory(self, self._url)
        f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
        p = urlparse(self._url)
        ep = self._make_endpoint(p.hostname, p.port or 80)
        # TODO: change/wrap ClientService to fail if the first attempt fails
        self._connector = internet.ClientService(ep, f)
        self._debug_record_inbound_f = None

    def set_trace(self, f):
        self._trace = f

    def _debug(self, what):
        if self._trace:
            self._trace(old_state="", input=what, new_state="")

    def _make_endpoint(self, hostname, port):
        if self._tor:
            # TODO: when we enable TLS, maybe add tls=True here
            return self._tor.stream_via(hostname, port)
        return endpoints.HostnameEndpoint(self._reactor, hostname, port)

    def wire(self, boss, nameplate, mailbox, allocator, lister, terminator):
        self._B = _interfaces.IBoss(boss)
        self._N = _interfaces.INameplate(nameplate)
        self._M = _interfaces.IMailbox(mailbox)
        self._A = _interfaces.IAllocator(allocator)
        self._L = _interfaces.ILister(lister)
        self._T = _interfaces.ITerminator(terminator)

    # from Boss
    def start(self):
        self._connector.startService()

    # from Mailbox
    def tx_claim(self, nameplate):
        self._tx("claim", nameplate=nameplate)

    def tx_open(self, mailbox):
        self._tx("open", mailbox=mailbox)

    def tx_add(self, phase, body):
        assert isinstance(phase, type("")), type(phase)
        assert isinstance(body, type(b"")), type(body)
        self._tx("add", phase=phase, body=bytes_to_hexstr(body))

    def tx_release(self, nameplate):
        self._tx("release", nameplate=nameplate)

    def tx_close(self, mailbox, mood):
        self._tx("close", mailbox=mailbox, mood=mood)

    def stop(self):
        # ClientService.stopService is defined to "Stop attempting to
        # reconnect and close any existing connections"
        d = defer.maybeDeferred(self._connector.stopService)
        d.addErrback(log.err)  # TODO: deliver error upstairs?
        d.addBoth(self._stopped)

    # from Lister
    def tx_list(self):
        self._tx("list")

    # from Code
    def tx_allocate(self):
        self._tx("allocate")

    # from our WSClient (the WebSocket protocol)
    def ws_open(self, proto):
        self._debug("R.connected")
        self._ws = proto
        try:
            self._tx("bind", appid=self._appid, side=self._side)
            self._N.connected()
            self._M.connected()
            self._L.connected()
            self._A.connected()
        except Exception as e:
            self._B.error(e)
            raise
        self._debug("R.connected finished notifications")

    def ws_message(self, payload):
        msg = bytes_to_dict(payload)
        if msg["type"] != "ack":
            self._debug("R.rx(%s %s%s)" % (
                msg["type"],
                msg.get("phase", ""),
                "[mine]" if msg.get("side", "") == self._side else "",
            ))

        self._timing.add("ws_receive", _side=self._side, message=msg)
        if self._debug_record_inbound_f:
            self._debug_record_inbound_f(msg)
        mtype = msg["type"]
        meth = getattr(self, "_response_handle_" + mtype, None)
        if not meth:
            # make tests fail, but real application will ignore it
            log.err(
                errors._UnknownMessageTypeError(
                    "Unknown inbound message type %r" % (msg, )))
            return
        try:
            return meth(msg)
        except Exception as e:
            log.err(e)
            self._B.error(e)
            raise

    def ws_close(self, wasClean, code, reason):
        self._debug("R.lost")
        self._ws = None
        self._N.lost()
        self._M.lost()
        self._L.lost()
        self._A.lost()

    # internal
    def _stopped(self, res):
        self._T.stopped()

    def _tx(self, mtype, **kwargs):
        assert self._ws
        # msgid is used by misc/dump-timing.py to correlate our sends with
        # their receives, and vice versa. They are also correlated with the
        # ACKs we get back from the server (which we otherwise ignore). There
        # are so few messages, 16 bits is enough to be mostly-unique.
        kwargs["id"] = bytes_to_hexstr(os.urandom(2))
        kwargs["type"] = mtype
        self._debug("R.tx(%s %s)" % (mtype.upper(), kwargs.get("phase", "")))
        payload = dict_to_bytes(kwargs)
        self._timing.add("ws_send", _side=self._side, **kwargs)
        self._ws.sendMessage(payload, False)

    def _response_handle_allocated(self, msg):
        nameplate = msg["nameplate"]
        assert isinstance(nameplate, type("")), type(nameplate)
        self._A.rx_allocated(nameplate)

    def _response_handle_nameplates(self, msg):
        # we get list of {id: ID}, with maybe more attributes in the future
        nameplates = msg["nameplates"]
        assert isinstance(nameplates, list), type(nameplates)
        nids = set()
        for n in nameplates:
            assert isinstance(n, dict), type(n)
            nameplate_id = n["id"]
            assert isinstance(nameplate_id, type("")), type(nameplate_id)
            nids.add(nameplate_id)
        # deliver a set of nameplate ids
        self._L.rx_nameplates(nids)

    def _response_handle_ack(self, msg):
        pass

    def _response_handle_error(self, msg):
        # the server sent us a type=error. Most cases are due to our mistakes
        # (malformed protocol messages, sending things in the wrong order),
        # but it can also result from CrowdedError (more than two clients
        # using the same channel).
        err = msg["error"]
        orig = msg["orig"]
        self._B.rx_error(err, orig)

    def _response_handle_welcome(self, msg):
        self._B.rx_welcome(msg["welcome"])

    def _response_handle_claimed(self, msg):
        mailbox = msg["mailbox"]
        assert isinstance(mailbox, type("")), type(mailbox)
        self._N.rx_claimed(mailbox)

    def _response_handle_message(self, msg):
        side = msg["side"]
        phase = msg["phase"]
        assert isinstance(phase, type("")), type(phase)
        body = hexstr_to_bytes(msg["body"])  # bytes
        self._M.rx_message(side, phase, body)

    def _response_handle_released(self, msg):
        self._N.rx_released()

    def _response_handle_closed(self, msg):
        self._M.rx_closed()
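The snippet above relies on small helpers (dict_to_bytes, bytes_to_dict, bytes_to_hexstr) that are not shown here. The stand-ins below are simplified assumptions about their behavior, included only to illustrate the JSON framing that _tx() and ws_message() expect; they are not the library's actual utility functions.

import json
import os


def dict_to_bytes(d):
    # JSON-encode, then UTF-8 encode (the payload _tx() sends over the WebSocket)
    return json.dumps(d).encode("utf-8")


def bytes_to_dict(b):
    return json.loads(b.decode("utf-8"))


def bytes_to_hexstr(b):
    return b.hex()


msg = {"type": "claim", "nameplate": "4"}
msg["id"] = bytes_to_hexstr(os.urandom(2))  # 16 bits, mostly unique per message
payload = dict_to_bytes(msg)
assert bytes_to_dict(payload)["type"] == "claim"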
Example no. 19
class AbstractCrossSituationalLearner(AbstractTemplateLearnerNew, ABC):
    """
    An abstract implementation of the cross-situational learning model.

    This learner stores all possible meanings for a word and narrows them down to one
    by computing association scores, and probabilities based on those scores, for each
    utterance-situation pair. It does so by associating every word with candidate meanings.
    For new words, meanings that are not already strongly associated with another word are
    associated evenly. For words encountered before, the word is associated more strongly with
    meanings it has co-occurred with and less strongly with new meanings. Very familiar
    word-meaning pairs are associated exclusively with each other; these correspond to words
    generally considered lexicalized. Once associations are made, a probability that each
    word-meaning pair is correct is calculated. Finally, if that probability is high enough,
    the word is lexicalized. More information can be
    found here: https://onlinelibrary.wiley.com/doi/full/10.1111/j.1551-6709.2010.01104.x
    """
    @attrs(slots=True, eq=False, frozen=True)
    class Hypothesis:
        pattern_template: PerceptionGraphTemplate = attrib(
            validator=instance_of(PerceptionGraphTemplate))
        association_score: float = attrib(validator=instance_of(float),
                                          default=0)
        probability: float = attrib(validator=in_(Range.open(0, 1)), default=0)
        observation_count: int = attrib(default=1)

    _ontology: Ontology = attrib(validator=instance_of(Ontology), kw_only=True)
    _observation_num = attrib(init=False, default=0)
    _surface_template_to_concept: Dict[SurfaceTemplate,
                                       Concept] = attrib(init=False,
                                                         default=Factory(dict))
    _concept_to_surface_template: Dict[Concept, SurfaceTemplate] = attrib(
        init=False, default=Factory(dict))
    _concept_to_hypotheses: ImmutableDict[
        Concept,
        ImmutableSet["AbstractCrossSituationalLearner.Hypothesis"]] = attrib(
            init=False, default=Factory(dict))

    # Learner Internal Values
    _smoothing_parameter: float = attrib(validator=in_(
        Range.greater_than(0.0)),
                                         kw_only=True)
    """
    This smoothing factor is added to the scores of all hypotheses
    when forming a probability distribution over hypotheses.
    This should be a small value, at most 0.1 and possibly much less.
    See section 3.3 of the Cross-Situational paper.
    """
    _expected_number_of_meanings: float = attrib(validator=in_(
        Range.greater_than(0.0)),
                                                 kw_only=True)
    _graph_match_confirmation_threshold: float = attrib(default=0.8,
                                                        kw_only=True)
    _lexicon_entry_threshold: float = attrib(default=0.8, kw_only=True)
    _minimum_observation_amount: int = attrib(default=5, kw_only=True)

    _concepts_in_utterance: ImmutableSet[Concept] = attrib(
        init=False, default=immutableset())
    _updated_hypotheses: Dict[Concept,
                              ImmutableSet[Hypothesis]] = attrib(init=False,
                                                                 factory=dict)

    # Corresponds to the dummy word from the paper
    _dummy_concept: Concept = attrib(init=False)

    # Debug Values
    _debug_callback: Optional[DebugCallableType] = attrib(default=None)
    _graph_logger: Optional[GraphLogger] = attrib(validator=optional(
        instance_of(GraphLogger)),
                                                  default=None)

    @_dummy_concept.default
    def _init_dummy_concept(self):
        return self._new_concept("_cross_situational_dummy_concept")

    def _pre_learning_step(
        self, language_perception_semantic_alignment:
        LanguagePerceptionSemanticAlignment
    ) -> None:
        # Figure out what "words" (concepts) appear in the utterance.
        concepts_in_utterance = []
        for other_bound_surface_template in self._candidate_templates(
                language_perception_semantic_alignment):
            # We have seen this template before and already have a concept for it
            if (other_bound_surface_template.surface_template
                    in self._surface_template_to_concept):
                concept = self._surface_template_to_concept[
                    other_bound_surface_template.surface_template]
            # Otherwise, make a new concept for it
            else:
                concept = self._new_concept(
                    debug_string=other_bound_surface_template.surface_template.
                    to_short_string())
            self._surface_template_to_concept[
                other_bound_surface_template.surface_template] = concept
            self._concept_to_surface_template[
                concept] = other_bound_surface_template.surface_template
            concepts_in_utterance.append(concept)
        self._concepts_in_utterance = immutableset(concepts_in_utterance)

        # We only need to make a shallow copy of our old hypotheses
        # because the values of self._concept_to_hypotheses are immutable.
        self._updated_hypotheses = dict(self._concept_to_hypotheses)

    def _learning_step(
        self,
        language_perception_semantic_alignment:
        LanguagePerceptionSemanticAlignment,
        bound_surface_template: SurfaceTemplateBoundToSemanticNodes,
    ) -> None:
        """
        Try to learn the semantics of a `SurfaceTemplate` given the assumption
        that its argument slots (if any) are bound to objects according to
        *bound_surface_template*.

        For example, "try to learn the meaning of 'red' given the language 'red car'
        and an alignment of 'car' to particular perceptions in the perception graph."
        """
        # Generate all possible meanings from the Graph
        meanings_from_perception = immutableset(
            self._hypotheses_from_perception(
                language_perception_semantic_alignment,
                bound_surface_template))
        meanings_to_pattern_template: Mapping[
            PerceptionGraph, PerceptionGraphTemplate] = immutabledict(
                (meaning,
                 PerceptionGraphTemplate.from_graph(meaning, immutabledict()))
                for meaning in meanings_from_perception)

        # We check for meanings that are described by lexicalized concepts
        # and don't try to learn those lexicalized concepts further.
        # jac: Not mentioned in the part of the paper I read. New?
        concepts_to_remove: Set[Concept] = set()

        def check_and_remove_meaning(
            other_concept: Concept,
            hypothesis: "AbstractCrossSituationalLearner.Hypothesis",
            *,
            ontology: Ontology,
        ) -> None:
            match = compute_match_ratio(
                hypothesis.pattern_template,
                language_perception_semantic_alignment.
                perception_semantic_alignment.perception_graph,
                ontology=ontology,
            )
            if match and match.matching_subgraph:
                for meaning in meanings_from_perception:
                    if match.matching_subgraph.check_isomorphism(
                            meanings_to_pattern_template[meaning].graph_pattern
                    ):
                        concepts_to_remove.add(other_concept)

        for (other_concept, hypotheses) in self._concept_to_hypotheses.items():
            for hypothesis in hypotheses:
                if hypothesis.probability > self._lexicon_entry_threshold:
                    check_and_remove_meaning(other_concept,
                                             hypothesis,
                                             ontology=self._ontology)

        # We have seen this template before and already have a concept for it
        # So we attempt to verify our already picked concept
        if bound_surface_template.surface_template in self._surface_template_to_concept:
            # We don't directly associate surface templates with perceptions.
            # Instead we mediate the relationship with "concept" objects.
            # These don't matter now, but the split might be helpful in the future
            # when we might have multiple ways of expressing the same idea.
            concept = self._surface_template_to_concept[
                bound_surface_template.surface_template]
        else:
            concept = self._new_concept(debug_string=bound_surface_template.
                                        surface_template.to_short_string())
        self._surface_template_to_concept[
            bound_surface_template.surface_template] = concept
        self._concept_to_surface_template[
            concept] = bound_surface_template.surface_template

        concepts_after_preprocessing = immutableset([
            concept for concept in self._concepts_in_utterance
            if concept not in concepts_to_remove
            # TODO Does it make sense to include a dummy concept/"word"? The paper has one so I
            #  am including it for now.
        ] + [self._dummy_concept])

        # Step 0. Update priors for any meanings as-yet unobserved.

        # Step 1. Compute alignment probabilities (pp. 1029)
        # We have an identified "word" (concept) from U(t)
        # and a collection of meanings from the scene S(t).
        # We now want to calculate the alignment probabilities,
        # which will be used to update this concept's association scores, assoc(w|m, U(t), S(t)),
        # and meaning probabilities, p(m|w).
        alignment_probabilities = self._get_alignment_probabilities(
            concepts_after_preprocessing, meanings_from_perception)

        # We have an identified "word" (concept) from U(t)
        # and a collection of meanings from the scene S(t).
        # We now want to update p(.|w), which means calculating the probabilities.
        new_hypotheses = self._updated_meaning_probabilities(
            concept,
            meanings_from_perception,
            meanings_to_pattern_template,
            alignment_probabilities,
        )

        # Finally, update our hypotheses for this concept
        self._updated_hypotheses[concept] = new_hypotheses

    def _post_learning_step(
        self, language_perception_semantic_alignment:
        LanguagePerceptionSemanticAlignment
    ) -> None:
        # Finish updating hypotheses
        # We have to do this as a separate step
        # so that we can update our hypotheses for each concept
        # independently of the hypotheses for the other concepts,
        # as in the algorithm described by the paper.
        self._concept_to_hypotheses = immutabledict(self._updated_hypotheses)
        self._updated_hypotheses.clear()

    def _get_alignment_probabilities(
        self, concepts: Iterable[Concept],
        meanings: ImmutableSet[PerceptionGraph]
    ) -> ImmutableDict[Concept, ImmutableDict[PerceptionGraph, float]]:
        """
        Compute the concept-(concrete meaning) alignment probabilities for a given word
        as defined by the paper below:

        a(c|m, U(t), S(t)) = p^(t-1)(m|c) / sum(for c' in (U(t) union {d})) p^(t-1)(m|c')

        where c is the given concept, m is the given meaning, U(t) is the set of concepts in the
        current utterance, and d is the dummy concept.
        https://onlinelibrary.wiley.com/doi/full/10.1111/j.1551-6709.2010.01104.x (3)
        """
        def meaning_probability(meaning: PerceptionGraph,
                                concept: Concept) -> float:
            """
            Return the meaning probability p^(t-1)(m|c).
            """
            # If we've already observed this concept before,
            if concept in self._concept_to_hypotheses:
                # And if we've already observed this meaning before,
                maybe_ratio_with_preexisting_hypothesis = self._find_similar_hypothesis(
                    meaning, self._concept_to_hypotheses[concept])
                if maybe_ratio_with_preexisting_hypothesis:
                    # return the prior probability.
                    _, preexisting_hypothesis = maybe_ratio_with_preexisting_hypothesis
                    return preexisting_hypothesis.probability
                # Otherwise, if we have observed this concept before
                # but not paired with a perception like this meaning,
                # it is assigned zero probability.
                # Is this correct?
                else:
                    return 0.0
            # If we haven't observed this concept before,
            # its prior probability is evenly split among all the observed meanings in this perception.
            else:
                return 1.0 / len(meanings)

        meaning_to_concept_to_alignment_probability: Dict[
            PerceptionGraph, ImmutableDict[Concept, float]] = dict()
        for meaning in iter(meanings):
            # We want to calculate the alignment probabilities for each concept against this meaning.
            # First, we compute the prior meaning probabilities p(m|c),
            # the probability that the concept c means m for each meaning m observed in the scene.
            concept_to_meaning_probability: Mapping[Concept,
                                                    float] = immutabledict({
                                                        concept:
                                                        meaning_probability(
                                                            meaning, concept)
                                                        for concept in concepts
                                                    })
            total_probability_mass: float = sum(
                concept_to_meaning_probability.values())

            # We use these to calculate the alignment probabilities a(c|m, U(t), S(t)).
            meaning_to_concept_to_alignment_probability[
                meaning] = immutabledict({
                    concept: meaning_probability_ / total_probability_mass
                    for concept, meaning_probability_ in
                    concept_to_meaning_probability.items()
                })

        # Restructure meaning_to_concept_to_alignment_probability
        # to get a map concept_to_meaning_to_alignment_probability.
        return immutabledict([
            (concept, immutabledict([(meaning, alignment_probability)]))
            for meaning, concept_to_alignment_probability in
            meaning_to_concept_to_alignment_probability.items() for concept,
            alignment_probability in concept_to_alignment_probability.items()
        ])

    def _updated_meaning_probabilities(
        self,
        concept: Concept,
        meanings: Iterable[PerceptionGraph],
        meaning_to_pattern: Mapping[PerceptionGraph, PerceptionGraphTemplate],
        alignment_probabilities: Mapping[Concept, Mapping[PerceptionGraph,
                                                          float]],
    ) -> ImmutableSet["AbstractCrossSituationalLearner.Hypothesis"]:
        """
        Update all concept-(abstract meaning) probabilities for a given word
        as defined by the paper below:

        p(m|c) = (assoc(m, c) + lambda) / (sum(for m' in M)(assoc(m', c)) + (beta * lambda))

        where c and m are the given concept and meaning, lambda is a smoothing factor, M is the set
        of all meanings encountered, and beta is an upper bound on the expected number of meaning types.
        https://onlinelibrary.wiley.com/doi/full/10.1111/j.1551-6709.2010.01104.x (3)
        """
        old_hypotheses = self._concept_to_hypotheses.get(
            concept, immutableset())

        # First we calculate the new association scores for each observed meaning.
        # If a meaning was not observed this instance, we don't change its association score at all.
        updated_hypotheses: Set[
            "AbstractCrossSituationalLearner.Hypothesis"] = set()
        hypothesis_updates: List[
            "AbstractCrossSituationalLearner.Hypothesis"] = []
        for meaning in meanings:
            # We use a placeholder probability to keep the hypothesis constructor happy.
            # We are going to fix up this probability later.
            placeholder_probability = 0.5
            # First, check if we've observed this meaning before.
            ratio_similar_hypothesis_pair = self._find_similar_hypothesis(
                meaning, old_hypotheses)
            if ratio_similar_hypothesis_pair is not None:
                ratio, similar_hypothesis = ratio_similar_hypothesis_pair

                # If we *have* observed this meaning before,
                # we need to update the existing hypothesis for it.
                if ratio.match_ratio > self._graph_match_confirmation_threshold:
                    # Mark the old hypothesis as updated
                    # so we don't include both the old and new hypothesis in our output.
                    updated_hypotheses.add(similar_hypothesis)
                    new_association_score = (
                        similar_hypothesis.association_score +
                        alignment_probabilities[concept][meaning])
                    new_observation_count = similar_hypothesis.observation_count + 1
                    new_hypothesis = AbstractCrossSituationalLearner.Hypothesis(
                        pattern_template=similar_hypothesis.pattern_template,
                        association_score=new_association_score,
                        probability=placeholder_probability,
                        observation_count=new_observation_count,
                    )
                    hypothesis_updates.append(new_hypothesis)
                    continue

            # If we *haven't* observed this meaning before,
            # we need to create a new hypothesis for it.
            new_hypothesis = AbstractCrossSituationalLearner.Hypothesis(
                pattern_template=meaning_to_pattern[meaning],
                association_score=alignment_probabilities[concept][meaning],
                probability=placeholder_probability,
                observation_count=1,
            )
            hypothesis_updates.append(new_hypothesis)

        # Now we calculate the updated meaning probabilities p(m|w).
        total_association_score = sum(hypothesis.association_score
                                      for hypothesis in hypothesis_updates)
        smoothing_term = self._expected_number_of_meanings * self._smoothing_parameter
        return immutableset(
            chain(
                # Include old hypotheses that weren't updated
                [
                    old_hypothesis for old_hypothesis in old_hypotheses
                    if old_hypothesis not in updated_hypotheses
                ],
                # Include new and updated hypotheses
                [
                    evolve(
                        hypothesis,
                        # Replace the placeholder meaning probability with the true meaning probability,
                        # calculated using the association scores and smoothing term.
                        probability=(hypothesis.association_score +
                                     self._smoothing_parameter) /
                        (total_association_score + smoothing_term),
                    ) for hypothesis in hypothesis_updates
                ],
            ))

    def _find_similar_hypothesis(
        self,
        new_meaning: PerceptionGraph,
        candidates: Iterable["AbstractCrossSituationalLearner.Hypothesis"],
    ) -> Optional[Tuple[PartialMatchRatio,
                        "AbstractCrossSituationalLearner.Hypothesis"]]:
        """
        Finds the hypothesis in candidates most similar to new_meaning and returns it
        together with the match ratio.

        Returns None if no candidate can be found that is sufficiently similar to new_meaning. A candidate is
        sufficiently similar if and only if its match ratio with new_meaning is at least
        _graph_match_confirmation_threshold.
        """
        candidates_iter = iter(candidates)
        match = None
        while match is None:
            try:
                existing_hypothesis = next(candidates_iter)
            except StopIteration:
                return None

            try:
                match = compute_match_ratio(
                    existing_hypothesis.pattern_template,
                    new_meaning,
                    ontology=self._ontology,
                )
            except RuntimeError:
                # Occurs when no matches of the pattern are found in the graph. This seems
                # to indicate some full matches and some matches with no intersection at all.
                pass

        for candidate in candidates:
            try:
                new_match = compute_match_ratio(candidate.pattern_template,
                                                new_meaning,
                                                ontology=self._ontology)
            except RuntimeError:
                # Occurs when no matches of the pattern are found in the graph. This seems
                # to indicate some full matches and some matches with no intersection at all.
                new_match = None
            if new_match and new_match.match_ratio > match.match_ratio:
                match = new_match
                existing_hypothesis = candidate
        if (match.match_ratio >= self._graph_match_confirmation_threshold
                and match.matching_subgraph and existing_hypothesis):
            return match, existing_hypothesis
        else:
            return None

    def templates_for_concept(
            self, concept: Concept) -> AbstractSet[SurfaceTemplate]:
        if concept in self._concept_to_surface_template:
            return immutableset([self._concept_to_surface_template[concept]])
        else:
            return immutableset()

    def concepts_to_patterns(self) -> Dict[Concept, PerceptionGraphPattern]:
        def argmax(hypotheses):
            # TODO is this key correct? what IS our "best hypothesis"?
            return max(
                hypotheses,
                key=lambda hypothesis: (
                    hypothesis.probability,
                    hypothesis.association_score,
                ),
            )

        return {
            concept: argmax(hypotheses).pattern_template.graph_pattern
            for concept, hypotheses in self._concept_to_hypotheses.items()
        }

    @abstractmethod
    def _new_concept(self, debug_string: str) -> Concept:
        """
        Create a new `Concept` of the appropriate type with the given *debug_string*.
        """

    @abstractmethod
    def _hypotheses_from_perception(
        self,
        learning_state: LanguagePerceptionSemanticAlignment,
        bound_surface_template: SurfaceTemplateBoundToSemanticNodes,
    ) -> Iterable[PerceptionGraph]:
        """
        Get a hypothesis for the meaning of *surface_template* from a given *learning_state*.
        """

    def _primary_templates(
            self) -> Iterable[Tuple[Concept, PerceptionGraphTemplate, float]]:
        return (
            (concept, hypothesis.pattern_template, hypothesis.probability)
            for (concept, hypotheses) in self._concept_to_hypotheses.items()
            # We are confident in a hypothesis if its probability is at least our _lexicon_entry_threshold
            # and we've observed this concept at least our _minimum_observation_amount times
            for hypothesis in hypotheses
            if hypothesis.observation_count >= self._minimum_observation_amount
            and hypothesis.probability >= self._lexicon_entry_threshold)

    def _fallback_templates(
            self) -> Iterable[Tuple[Concept, PerceptionGraphTemplate, float]]:
        # Alternate hypotheses that fall below either our _lexicon_entry_threshold or our _minimum_observation_amount
        return (
            (concept, hypothesis.pattern_template, hypothesis.probability)
            for (concept, hypotheses) in self._concept_to_hypotheses.items()
            for hypothesis in sorted(
                hypotheses,
                key=lambda hypothesis: hypothesis.probability,
                reverse=True)
            if hypothesis.observation_count < self._minimum_observation_amount
            or hypothesis.probability < self._lexicon_entry_threshold)

    def _match_template(
        self,
        *,
        concept: Concept,
        pattern: PerceptionGraphTemplate,
        perception_graph: PerceptionGraph,
    ) -> Iterable[Tuple[PerceptionGraphPatternMatch, SemanticNode]]:
        """
        Try to match our model of the semantics to the perception graph
        """
        partial_match = compute_match_ratio(
            pattern,
            perception_graph,
            ontology=self._ontology,
            graph_logger=self._graph_logger,
            debug_callback=self._debug_callback,
        )

        if (partial_match.match_ratio >=
                self._graph_match_confirmation_threshold
                and partial_match.matching_subgraph):
            # if there is a match, which is above our minimum match ratio
            # Use that pattern to try and find a match in the scene
            # There should be one
            # TODO: This currently means we match to the graph multiple times. Reduce this?
            matcher = partial_match.matching_subgraph.matcher(
                perception_graph,
                match_mode=MatchMode.NON_OBJECT,
                debug_callback=self._debug_callback,
            )
            found_match = False
            for match in matcher.matches(use_lookahead_pruning=True):
                found_match = True
                semantic_node_for_match = pattern_match_to_semantic_node(
                    concept=concept, pattern=pattern, match=match)
                yield match, semantic_node_for_match
            # We raise an error if we find a partial match but don't manage to match it to the scene
            if not found_match:
                raise RuntimeError(
                    f"Partial Match found for {concept} below match ratio however pattern "
                    f"subgraph was unable to match to perception graph.\n"
                    f"Partial Match: {partial_match}\n"
                    f"Perception Graph: {perception_graph}")
Example no. 20
    def test_success_with_none(self):
        """
        Nothing happens if None.
        """
        v = optional(instance_of(int))
        v(None, simple_attr("test"), None)
Example no. 21
class Manager(object):
    _S = attrib(validator=provides(ISend), repr=False)
    _my_side = attrib(validator=instance_of(type(u"")))
    _transit_key = attrib(validator=instance_of(bytes), repr=False)
    _transit_relay_location = attrib(validator=optional(instance_of(str)))
    _reactor = attrib(repr=False)
    _eventual_queue = attrib(repr=False)
    _cooperator = attrib(repr=False)
    _no_listen = attrib(default=False)
    _tor = None  # TODO
    _timing = None  # TODO
    _next_subchannel_id = None  # initialized in choose_role

    m = MethodicalMachine()
    set_trace = getattr(m, "_setTrace",
                        lambda self, f: None)  # pragma: no cover

    def __attrs_post_init__(self):
        self._got_versions_d = Deferred()

        self._my_role = None  # determined upon rx_PLEASE

        self._connection = None
        self._made_first_connection = False
        self._first_connected = OneShotObserver(self._eventual_queue)
        self._stopped = OneShotObserver(self._eventual_queue)
        self._host_addr = _WormholeAddress()

        self._next_dilation_phase = 0

        # I kept getting confused about which methods were for inbound data
        # (and thus flow-control methods go "out") and which were for
        # outbound data (with flow-control going "in"), so I split them up
        # into separate pieces.
        self._inbound = Inbound(self, self._host_addr)
        self._outbound = Outbound(self, self._cooperator)  # from us to peer

    def set_listener_endpoint(self, listener_endpoint):
        self._inbound.set_listener_endpoint(listener_endpoint)

    def set_subchannel_zero(self, scid0, sc0):
        self._inbound.set_subchannel_zero(scid0, sc0)

    def when_first_connected(self):
        return self._first_connected.when_fired()

    def when_stopped(self):
        return self._stopped.when_fired()

    def send_dilation_phase(self, **fields):
        dilation_phase = self._next_dilation_phase
        self._next_dilation_phase += 1
        self._S.send("dilate-%d" % dilation_phase, dict_to_bytes(fields))

    def send_hints(self, hints):  # from Connector
        self.send_dilation_phase(type="connection-hints", hints=hints)

    # forward inbound-ish things to _Inbound

    def subchannel_pauseProducing(self, sc):
        self._inbound.subchannel_pauseProducing(sc)

    def subchannel_resumeProducing(self, sc):
        self._inbound.subchannel_resumeProducing(sc)

    def subchannel_stopProducing(self, sc):
        self._inbound.subchannel_stopProducing(sc)

    # forward outbound-ish things to _Outbound
    def subchannel_registerProducer(self, sc, producer, streaming):
        self._outbound.subchannel_registerProducer(sc, producer, streaming)

    def subchannel_unregisterProducer(self, sc):
        self._outbound.subchannel_unregisterProducer(sc)

    def send_open(self, scid):
        assert isinstance(scid, bytes)
        self._queue_and_send(Open, scid)

    def send_data(self, scid, data):
        assert isinstance(scid, bytes)
        self._queue_and_send(Data, scid, data)

    def send_close(self, scid):
        assert isinstance(scid, bytes)
        self._queue_and_send(Close, scid)

    def _queue_and_send(self, record_type, *args):
        r = self._outbound.build_record(record_type, *args)
        # Outbound owns the send_record() pipe, so that it can stall new
        # writes after a new connection is made until after all queued
        # messages are written (to preserve ordering).
        self._outbound.queue_and_send_record(r)  # may trigger pauseProducing

    def subchannel_closed(self, scid, sc):
        # let everyone clean up. This happens just after we delivered
        # connectionLost to the Protocol, except for the control channel,
        # which might get connectionLost later after they use ep.connect.
        # TODO: is this inversion a problem?
        self._inbound.subchannel_closed(scid, sc)
        self._outbound.subchannel_closed(scid, sc)

    # our Connector calls these

    def connector_connection_made(self, c):
        self.connection_made()  # state machine update
        self._connection = c
        self._inbound.use_connection(c)
        self._outbound.use_connection(c)  # does c.registerProducer
        if not self._made_first_connection:
            self._made_first_connection = True
            self._first_connected.fire(None)

    def connector_connection_lost(self):
        self._stop_using_connection()
        if self._my_role is LEADER:
            self.connection_lost_leader()  # state machine
        else:
            self.connection_lost_follower()

    def _stop_using_connection(self):
        # the connection is already lost by this point
        self._connection = None
        self._inbound.stop_using_connection()
        self._outbound.stop_using_connection()  # does c.unregisterProducer

    # from our active Connection

    def got_record(self, r):
        # records with sequence numbers: always ack, ignore old ones
        if isinstance(r, (Open, Data, Close)):
            self.send_ack(r.seqnum)  # always ack, even for old ones
            if self._inbound.is_record_old(r):
                return
            self._inbound.update_ack_watermark(r.seqnum)
            if isinstance(r, Open):
                self._inbound.handle_open(r.scid)
            elif isinstance(r, Data):
                self._inbound.handle_data(r.scid, r.data)
            else:  # isinstance(r, Close)
                self._inbound.handle_close(r.scid)
            return
        if isinstance(r, KCM):
            log.err(UnexpectedKCM())
        elif isinstance(r, Ping):
            self.handle_ping(r.ping_id)
        elif isinstance(r, Pong):
            self.handle_pong(r.ping_id)
        elif isinstance(r, Ack):
            self._outbound.handle_ack(r.resp_seqnum)  # retire queued messages
        else:
            log.err(UnknownMessageType("{}".format(r)))

    # pings, pongs, and acks are not queued
    def send_ping(self, ping_id):
        self._outbound.send_if_connected(Ping(ping_id))

    def send_pong(self, ping_id):
        self._outbound.send_if_connected(Pong(ping_id))

    def send_ack(self, resp_seqnum):
        self._outbound.send_if_connected(Ack(resp_seqnum))

    def handle_ping(self, ping_id):
        self.send_pong(ping_id)

    def handle_pong(self, ping_id):
        # TODO: update is-alive timer
        pass

    # subchannel maintenance
    def allocate_subchannel_id(self):
        scid_num = self._next_subchannel_id
        self._next_subchannel_id += 2
        return to_be4(scid_num)

    # state machine

    # We are born WANTING after the local app calls w.dilate(). We start
    # CONNECTING when we receive PLEASE from the remote side

    def start(self):
        self.send_please()

    def send_please(self):
        self.send_dilation_phase(type="please", side=self._my_side)

    @m.state(initial=True)
    def WANTING(self):
        pass  # pragma: no cover

    @m.state()
    def CONNECTING(self):
        pass  # pragma: no cover

    @m.state()
    def CONNECTED(self):
        pass  # pragma: no cover

    @m.state()
    def FLUSHING(self):
        pass  # pragma: no cover

    @m.state()
    def ABANDONING(self):
        pass  # pragma: no cover

    @m.state()
    def LONELY(self):
        pass  # pragma: no cover

    @m.state()
    def STOPPING(self):
        pass  # pragma: no cover

    @m.state(terminal=True)
    def STOPPED(self):
        pass  # pragma: no cover

    @m.input()
    def rx_PLEASE(self, message):
        pass  # pragma: no cover

    @m.input()  # only sent by Follower
    def rx_HINTS(self, hint_message):
        pass  # pragma: no cover

    @m.input()  # only Leader sends RECONNECT, so only Follower receives it
    def rx_RECONNECT(self):
        pass  # pragma: no cover

    @m.input()  # only Follower sends RECONNECTING, so only Leader receives it
    def rx_RECONNECTING(self):
        pass  # pragma: no cover

    # Connector gives us connection_made()
    @m.input()
    def connection_made(self):
        pass  # pragma: no cover

    # our connection_lost() fires connection_lost_leader or
    # connection_lost_follower depending upon our role. If either side sees a
    # problem with the connection (timeouts, bad authentication) then they
    # just drop it and let connection_lost() handle the cleanup.
    @m.input()
    def connection_lost_leader(self):
        pass  # pragma: no cover

    @m.input()
    def connection_lost_follower(self):
        pass

    @m.input()
    def stop(self):
        pass  # pragma: no cover

    @m.output()
    def choose_role(self, message):
        their_side = message["side"]
        if self._my_side > their_side:
            self._my_role = LEADER
            # scid 0 is reserved for the control channel. the leader uses odd
            # numbers starting with 1
            self._next_subchannel_id = 1
        elif their_side > self._my_side:
            self._my_role = FOLLOWER
            # the follower uses even numbers starting with 2
            self._next_subchannel_id = 2
        else:
            raise ValueError("their side shouldn't be equal: reflection?")

    # these Outputs behave differently for the Leader vs the Follower

    @m.output()
    def start_connecting_ignore_message(self, message):
        del message  # ignored
        return self._start_connecting()

    @m.output()
    def start_connecting(self):
        self._start_connecting()

    def _start_connecting(self):
        assert self._my_role is not None
        self._connector = Connector(
            self._transit_key,
            self._transit_relay_location,
            self,
            self._reactor,
            self._eventual_queue,
            self._no_listen,
            self._tor,
            self._timing,
            self._my_side,  # needed for relay handshake
            self._my_role)
        self._connector.start()

    @m.output()
    def send_reconnect(self):
        self.send_dilation_phase(type="reconnect")  # TODO: generation number?

    @m.output()
    def send_reconnecting(self):
        self.send_dilation_phase(type="reconnecting")  # TODO: generation?

    @m.output()
    def use_hints(self, hint_message):
        hint_objs = filter(
            lambda h: h,  # ignore None, unrecognizable
            [parse_hint(hs) for hs in hint_message["hints"]])
        hint_objs = list(hint_objs)
        self._connector.got_hints(hint_objs)

    @m.output()
    def stop_connecting(self):
        self._connector.stop()

    @m.output()
    def abandon_connection(self):
        # we think we're still connected, but the Leader disagrees. Or we've
        # been told to shut down.
        self._connection.disconnect()  # let connection_lost do cleanup

    @m.output()
    def notify_stopped(self):
        self._stopped.fire(None)

    # we start CONNECTING when we get rx_PLEASE
    WANTING.upon(rx_PLEASE,
                 enter=CONNECTING,
                 outputs=[choose_role, start_connecting_ignore_message])

    CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[])

    # Leader
    CONNECTED.upon(connection_lost_leader,
                   enter=FLUSHING,
                   outputs=[send_reconnect])
    FLUSHING.upon(rx_RECONNECTING,
                  enter=CONNECTING,
                  outputs=[start_connecting])

    # Follower
    # if we notice a lost connection, just wait for the Leader to notice too
    CONNECTED.upon(connection_lost_follower, enter=LONELY, outputs=[])
    LONELY.upon(rx_RECONNECT,
                enter=CONNECTING,
                outputs=[send_reconnecting, start_connecting])
    # but if they notice it first, abandon our (seemingly functional)
    # connection, then tell them that we're ready to try again
    CONNECTED.upon(rx_RECONNECT,
                   enter=ABANDONING,
                   outputs=[abandon_connection])
    ABANDONING.upon(connection_lost_follower,
                    enter=CONNECTING,
                    outputs=[send_reconnecting, start_connecting])
    # and if they notice a problem while we're still connecting, abandon our
    # incomplete attempt and try again. in this case we don't have to wait
    # for a connection to finish shutdown
    CONNECTING.upon(
        rx_RECONNECT,
        enter=CONNECTING,
        outputs=[stop_connecting, send_reconnecting, start_connecting])

    # rx_HINTS never changes state, they're just accepted or ignored
    WANTING.upon(rx_HINTS, enter=WANTING, outputs=[])  # too early
    CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints])
    CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[])  # too late, ignore
    FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[])  # stale, ignore
    LONELY.upon(rx_HINTS, enter=LONELY, outputs=[])  # stale, ignore
    ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[])  # shouldn't happen
    STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[])

    WANTING.upon(stop, enter=STOPPED, outputs=[notify_stopped])
    CONNECTING.upon(stop,
                    enter=STOPPED,
                    outputs=[stop_connecting, notify_stopped])
    CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection])
    ABANDONING.upon(stop, enter=STOPPING, outputs=[])
    FLUSHING.upon(stop, enter=STOPPED, outputs=[notify_stopped])
    LONELY.upon(stop, enter=STOPPED, outputs=[notify_stopped])
    STOPPING.upon(connection_lost_leader,
                  enter=STOPPED,
                  outputs=[notify_stopped])
    STOPPING.upon(connection_lost_follower,
                  enter=STOPPED,
                  outputs=[notify_stopped])
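The WANTING/CONNECTING/CONNECTED wiring above is built on Automat's MethodicalMachine. Below is a minimal, self-contained sketch of the same state/input/output/upon pattern; it is a toy two-state machine for illustration, not the wormhole Manager.

from automat import MethodicalMachine


class Light(object):
    m = MethodicalMachine()

    @m.state(initial=True)
    def off(self):
        "the light is off"

    @m.state()
    def on(self):
        "the light is on"

    @m.input()
    def flip(self):
        "flip the switch"

    @m.output()
    def _turn_on(self):
        print("turning on")

    @m.output()
    def _turn_off(self):
        print("turning off")

    # transitions: calling the input method runs the outputs and changes state
    off.upon(flip, enter=on, outputs=[_turn_on])
    on.upon(flip, enter=off, outputs=[_turn_off])


# light = Light(); light.flip() prints "turning on"; calling light.flip() again prints "turning off".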
Example no. 22
class Connector(object):
    """I manage a single generation of connection.

    The Manager creates one of me at a time, whenever it wants a connection
    (which is always, once w.dilate() has been called and we know the remote
    end can dilate, and is expressed by the Manager calling my .start()
    method). I am discarded when my established connection is lost (and if we
    still want to be connected, a new generation is started and a new
    Connector is created). I am also discarded if we stop wanting to be
    connected (which the Manager expresses by calling my .stop() method).

    I manage the race between multiple connections for a specific generation
    of the dilated connection.

    I send connection hints when my InboundConnectionFactory yields addresses
    (self.listener_ready), and I initiate outbound connections (with
    OutboundConnectionFactory) as I receive connection hints from my peer
    (self.got_hints). Both factories use my build_protocol() method to create
    connection.DilatedConnectionProtocol instances. I track these protocol
    instances until one finishes negotiation and wins the race. I then shut
    down the others, remember the winner as self._winning_connection, and
    deliver the winner to manager.connector_connection_made(c).

    When an active connection is lost, we call manager.connector_connection_lost,
    allowing the manager to decide whether it wants to start a new generation
    or not.
    """

    _dilation_key = attrib(validator=instance_of(type(b"")))
    _transit_relay_location = attrib(validator=optional(instance_of(type(u""))))
    _manager = attrib(validator=provides(IDilationManager))
    _reactor = attrib()
    _eventual_queue = attrib()
    _no_listen = attrib(validator=instance_of(bool))
    _tor = attrib()
    _timing = attrib()
    _side = attrib(validator=instance_of(type(u"")))
    # was self._side = bytes_to_hexstr(os.urandom(8)) # unicode
    _role = attrib()

    m = MethodicalMachine()
    set_trace = getattr(m, "_setTrace", lambda self, f: None)  # pragma: no cover

    RELAY_DELAY = 2.0

    def __attrs_post_init__(self):
        if self._transit_relay_location:
            # TODO: allow multiple hints for a single relay
            relay_hint = parse_hint_argv(self._transit_relay_location)
            relay = RelayV1Hint(hints=(relay_hint,))
            self._transit_relays = [relay]
        else:
            self._transit_relays = []
        self._listeners = set()  # IListeningPorts that can be stopped
        self._pending_connectors = set()  # Deferreds that can be cancelled
        self._pending_connections = EmptyableSet(
            _eventual_queue=self._eventual_queue)  # Protocols to be stopped
        self._contenders = set()  # viable connections
        self._winning_connection = None
        self._timing = self._timing or DebugTiming()
        self._timing.add("transit")

    # this describes what our Connector can do, for the initial advertisement
    @classmethod
    def get_connection_abilities(klass):
        return [{"type": "direct-tcp-v1"},
                {"type": "relay-v1"},
                ]

    def build_protocol(self, addr, description):
        # encryption: let's use Noise NNpsk0 (or maybe NNpsk2). That uses
        # ephemeral keys plus a pre-shared symmetric key (the Transit key), a
        # different one for each potential connection.
        noise = build_noise()
        noise.set_psks(self._dilation_key)
        if self._role is LEADER:
            noise.set_as_initiator()
            outbound_prologue = PROLOGUE_LEADER
            inbound_prologue = PROLOGUE_FOLLOWER
        else:
            noise.set_as_responder()
            outbound_prologue = PROLOGUE_FOLLOWER
            inbound_prologue = PROLOGUE_LEADER
        p = DilatedConnectionProtocol(self._eventual_queue, self._role,
                                      description,
                                      self, noise,
                                      outbound_prologue, inbound_prologue)
        return p

    @m.state(initial=True)
    def connecting(self):
        pass  # pragma: no cover

    @m.state()
    def connected(self):
        pass  # pragma: no cover

    @m.state(terminal=True)
    def stopped(self):
        pass  # pragma: no cover

    # TODO: unify the tense of these method-name verbs

    # add_relay() and got_hints() are called by the Manager as it receives
    # messages from our peer. stop() is called when the Manager shuts down
    @m.input()
    def add_relay(self, hint_objs):
        pass

    @m.input()
    def got_hints(self, hint_objs):
        pass

    @m.input()
    def stop(self):
        pass

    # called by ourselves, when _start_listener() is ready
    @m.input()
    def listener_ready(self, hint_objs):
        pass

    # called when DilatedConnectionProtocol submits itself, after KCM
    # received
    @m.input()
    def add_candidate(self, c):
        pass

    # called by ourselves, via consider()
    @m.input()
    def accept(self, c):
        pass


    @m.output()
    def use_hints(self, hint_objs):
        self._use_hints(hint_objs)

    @m.output()
    def publish_hints(self, hint_objs):
        self._publish_hints(hint_objs)

    def _publish_hints(self, hint_objs):
        self._manager.send_hints([encode_hint(h) for h in hint_objs])

    @m.output()
    def consider(self, c):
        self._contenders.add(c)
        if self._role is LEADER:
            # for now, just accept the first one. TODO: be clever.
            self._eventual_queue.eventually(self.accept, c)
        else:
            # the follower always uses the first contender, since that's the
            # only one the leader picked
            self._eventual_queue.eventually(self.accept, c)

    @m.output()
    def select_and_stop_remaining(self, c):
        self._winning_connection = c
        self._contenders.clear()  # we no longer care who else came close
        # remove this winner from the losers, so we don't shut it down
        self._pending_connections.discard(c)
        # shut down losing connections
        self.stop_listeners()  # TODO: maybe keep it open? NAT/p2p assist
        self.stop_pending_connectors()
        self.stop_pending_connections()

        c.select(self._manager)  # subsequent frames go directly to the manager
        # c.select also wires up when_disconnected() to fire
        # manager.connector_connection_lost(). TODO: rename this, since the
        # Connector is no longer the one calling it
        if self._role is LEADER:
            # TODO: this should live in Connection
            c.send_record(KCM())  # leader sends KCM now
        self._manager.connector_connection_made(c)  # manager sends frames to Connection

    @m.output()
    def stop_everything(self):
        self.stop_listeners()
        self.stop_pending_connectors()
        self.stop_pending_connections()
        self.break_cycles()

    def stop_listeners(self):
        d = DeferredList([l.stopListening() for l in self._listeners])
        self._listeners.clear()
        return d  # synchronization for tests

    def stop_pending_connectors(self):
        for d in self._pending_connectors:
            d.cancel()

    def stop_pending_connections(self):
        d = self._pending_connections.when_next_empty()
        [c.disconnect() for c in self._pending_connections]
        return d

    def break_cycles(self):
        # help GC by forgetting references to things that reference us
        self._listeners.clear()
        self._pending_connectors.clear()
        self._pending_connections.clear()
        self._winning_connection = None

    connecting.upon(listener_ready, enter=connecting, outputs=[publish_hints])
    connecting.upon(add_relay, enter=connecting, outputs=[use_hints,
                                                          publish_hints])
    connecting.upon(got_hints, enter=connecting, outputs=[use_hints])
    connecting.upon(add_candidate, enter=connecting, outputs=[consider])
    connecting.upon(accept, enter=connected, outputs=[
                    select_and_stop_remaining])
    connecting.upon(stop, enter=stopped, outputs=[stop_everything])

    # once connected, we ignore everything except stop
    connected.upon(listener_ready, enter=connected, outputs=[])
    connected.upon(add_relay, enter=connected, outputs=[])
    connected.upon(got_hints, enter=connected, outputs=[])
    # TODO: tell them to disconnect? will they hang out forever? I *think*
    # they'll drop this once they get a KCM on the winning connection.
    connected.upon(add_candidate, enter=connected, outputs=[])
    connected.upon(accept, enter=connected, outputs=[])
    connected.upon(stop, enter=stopped, outputs=[stop_everything])

    # from Manager: start, got_hints, stop
    # maybe add_candidate, accept

    def start(self):
        if not self._no_listen and not self._tor:
            addresses = self._get_listener_addresses()
            self._start_listener(addresses)
        if self._transit_relays:
            self._publish_hints(self._transit_relays)
            self._use_hints(self._transit_relays)

    def _get_listener_addresses(self):
        addresses = ipaddrs.find_addresses()
        non_loopback_addresses = [a for a in addresses if a != "127.0.0.1"]
        if non_loopback_addresses:
            # some test hosts, including the appveyor VMs, *only* have
            # 127.0.0.1, and the tests will hang badly if we remove it.
            addresses = non_loopback_addresses
        return addresses

    def _start_listener(self, addresses):
        # TODO: listen on a fixed port, if possible, for NAT/p2p benefits, also
        # to make firewall configs easier
        # TODO: retain listening port between connection generations?
        ep = serverFromString(self._reactor, "tcp:0")
        f = InboundConnectionFactory(self)
        d = ep.listen(f)

        def _listening(lp):
            # lp is an IListeningPort
            self._listeners.add(lp)  # for shutdown and tests
            portnum = lp.getHost().port
            direct_hints = [DirectTCPV1Hint(to_unicode(addr), portnum, 0.0)
                            for addr in addresses]
            self.listener_ready(direct_hints)
        d.addCallback(_listening)
        d.addErrback(log.err)

    def _schedule_connection(self, delay, h, is_relay):
        ep = endpoint_from_hint_obj(h, self._tor, self._reactor)
        desc = describe_hint_obj(h, is_relay, self._tor)
        d = deferLater(self._reactor, delay,
                       self._connect, ep, desc, is_relay)
        d.addErrback(log.err)
        self._pending_connectors.add(d)

    def _use_hints(self, hints):
        # first, pull out all the relays, we'll connect to them later
        relays = []
        direct = defaultdict(list)
        for h in hints:
            if isinstance(h, RelayV1Hint):
                relays.append(h)
            else:
                direct[h.priority].append(h)
        delay = 0.0
        made_direct = False
        priorities = sorted(set(direct.keys()), reverse=True)
        for p in priorities:
            for h in direct[p]:
                if isinstance(h, TorTCPV1Hint) and not self._tor:
                    continue
                self._schedule_connection(delay, h, is_relay=False)
                made_direct = True
                # Make all direct connections immediately. Later, we'll change
                # the add_candidate() function to look at the priority when
                # deciding whether to accept a successful connection or not,
                # and it can wait for more options if it sees a higher-priority
                # one still running. But if we bail on that, we might consider
                # putting an inter-direct-hint delay here to influence the
                # process.
                # delay += 1.0

        if made_direct and not self._no_listen:
            # Prefer direct connections by stalling relay connections by a
            # few seconds. We don't wait until direct connections have
            # failed, because many direct hints will be to unused
            # local-network IP addresses, which won't answer, and can take the
            # full 30s TCP timeout to fail.
            #
            # If we didn't make any direct connections, or we're using
            # --no-listen, then we're probably going to have to use the
            # relay, so don't delay it at all.
            delay += self.RELAY_DELAY

        # It might be nice to wire this so that a failure in the direct hints
        # causes the relay hints to be used right away (fast failover). But
        # none of our current use cases would take advantage of that: if we
        # have any viable direct hints, then they're either going to succeed
        # quickly or hang for a long time.
        for r in relays:
            for h in r.hints:
                self._schedule_connection(delay, h, is_relay=True)
        # TODO:
        # if not contenders:
        #    raise TransitError("No contenders for connection")

    # TODO: add 2*TIMEOUT deadline for first generation, don't wait forever for
    # the initial connection

    def _connect(self, ep, description, is_relay=False):
        relay_handshake = None
        if is_relay:
            relay_handshake = build_sided_relay_handshake(self._dilation_key,
                                                          self._side)
        f = OutboundConnectionFactory(self, relay_handshake, description)
        d = ep.connect(f)
        # fires with protocol, or ConnectError

        def _connected(p):
            self._pending_connections.add(p)
            # p might not be in _pending_connections by the time it
            # disconnects, if it turned out to be the winner, which is why we
            # use discard() and not remove()
            p.when_disconnected().addCallback(self._pending_connections.discard)
        d.addCallback(_connected)
        return d
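
A minimal, standalone sketch (not part of magic-wormhole) of the scheduling policy implemented in _use_hints above: direct hints are attempted immediately, and relay hints are stalled by RELAY_DELAY only when at least one direct connection was attempted and listening is enabled.

RELAY_DELAY = 2.0  # assumed value, for illustration only

def plan_delays(direct_hints, relay_hints, no_listen=False):
    """Return (hint, delay) pairs mirroring the prioritization above."""
    plan = [(h, 0.0) for h in direct_hints]        # direct: connect right away
    stall = RELAY_DELAY if (direct_hints and not no_listen) else 0.0
    plan.extend((h, stall) for h in relay_hints)   # relay: possibly stalled
    return plan

# plan_delays(["tcp:192.0.2.1:4001"], ["relay:192.0.2.2:4001"])
# -> [("tcp:192.0.2.1:4001", 0.0), ("relay:192.0.2.2:4001", 2.0)]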
Esempio n. 23
0
class Prediction:
    """MCMC prediction

    Parameters
    ----------
    ensemble : ndarray
        Ensemble of predictions. A 2d array (nxm) for n predictands and m
        ensemble members.
    temptype : str
        Type of sea temperature used for prediction.
    latlon : tuple or None, optional
        Optional tuple of the site location (lat, lon).
    prior_mean : float or None, optional
        Prior mean used for the prediction.
    prior_std : float or None, optional
        Prior sample standard deviation used for the prediction.
    modelparam_gridpoints : list of tuples or None, optional
        A list of one or more (lat, lon) points used to collect
        spatially-sensitive model parameters.
    analog_gridpoints : list of tuples or None, optional
        A list of one or more (lat, lon) points used for an analog prediction.
    """
    temptype = attr.ib()
    ensemble = attr.ib(validator=av.optional(av.instance_of(np.ndarray)))
    latlon = attr.ib(default=None,
                     validator=av.optional(av.instance_of(tuple)))
    prior_mean = attr.ib(default=None)
    prior_std = attr.ib(default=None)
    modelparam_gridpoints = attr.ib(default=None,
                                    validator=av.optional(
                                        av.instance_of(list)))
    analog_gridpoints = attr.ib(default=None,
                                validator=av.optional(av.instance_of(list)))

    def percentile(self, q=None, interpolation='nearest'):
        """Compute the qth ranked percentile from ensemble members.

        Parameters
        ----------
        q : float, sequence of floats, or None, optional
            Percentiles (i.e. [0, 100]) to compute. Default is 5%, 50%, 95%.
        interpolation : str, optional
            Passed to numpy.percentile. Default is 'nearest'.

        Returns
        -------
        perc : ndarray
            A 2d (nxm) array of floats where n is the number of predictands in
            the ensemble and m is the number of percentiles (``len(q)``).
        """
        if q is None:
            q = [5, 50, 95]
        q = np.array(q, dtype=np.float64, copy=True)

        # Because analog ensembles have 3 dims
        target_axis = list(range(self.ensemble.ndim))[1:]

        perc = np.percentile(self.ensemble,
                             q=q,
                             axis=target_axis,
                             interpolation=interpolation)
        return perc.T
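
A hypothetical usage of the Prediction class above (the shapes and the temptype value are assumptions; depending on the NumPy version, the interpolation keyword used internally may emit a deprecation warning):

import numpy as np

ensemble = np.random.default_rng(0).normal(size=(3, 1000))  # 3 predictands, 1000 members
pred = Prediction(temptype="sst", ensemble=ensemble)
perc = pred.percentile()   # defaults to q=[5, 50, 95]
print(perc.shape)          # (3, 3): one row per predictand, one column per percentile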
Esempio n. 24
0
class Issue:
    '''
    Contents of individual issue.

    Attributes
    attachments: callable that accepts zero arguments and returns a sequence of
                 stream-like objects
    '''
    reader      = attr.ib(validator=instance_of(ReaderBase))
    uid         = attr.ib()
    author      = attr.ib(type=Person)
    author_role = attr.ib(default='', converter=_role)
    status      = attr.ib(default='')
    title       = attr.ib(default='')
    body        = attr.ib(default='')
    original_url         = attr.ib(default='')
    labels      = attr.ib(factory=list)
    assignees   = attr.ib(factory=list)
    created_at  = attr.ib(default=None, validator=optional(instance_of(datetime)))
    modified_at = attr.ib(default=None, validator=optional(instance_of(datetime)))
    fetched_at  = attr.ib(default=None, validator=optional(instance_of(datetime)))
    closed_at   = attr.ib(default=None, validator=optional(instance_of(datetime)))
    attachments = attr.ib(default=list)


    @attachments.validator
    def _check_if_callable(self, attribute, value):
        if not callable(value):
            raise ValueError('{!r} must be a callable that accepts zero arguments'.format(
                attribute.name
            ))


    def comments(self, sort_by='created_at', desc=False):
        '''Yield related comments'''
        yield from self.reader._get_comments(self, sort_by, desc)


    def events(self, sort_by='created_at', desc=False):
        '''Yield related events'''
        yield from self.reader._get_events(self, sort_by, desc)


    def feed(self, desc=False):
        '''
        Yield comments and events in chronological order

        Generates tuples of (object, string) where string describes the kind of
        object being returned
        '''
        sort_by = 'created_at'

        comments = self.comments(sort_by, desc)
        events = self.events(sort_by, desc)

        comment, event = _get_next(comments), _get_next(events)
        while True:
            if not comment and not event:
                break
            elif (not event and comment) \
            or (comment and (comment.created_at <= event.created_at)):
                yield comment, 'comment'
                comment = _get_next(comments)
            else:
                yield event, 'event'
                event = _get_next(events)
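
The _get_next helper used by feed() above is not shown in this snippet; a plausible minimal implementation (an assumption, not the project's actual code) would be:

def _get_next(iterator, default=None):
    """Return the next item from the iterator, or `default` once exhausted."""
    return next(iterator, default)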
Esempio n. 25
0
class AudioPackFormat(ADMElement):
    audioPackFormatName = attrib(default=None,
                                 validator=instance_of(string_types))
    type = attrib(default=None, validator=instance_of(TypeDefinition))
    absoluteDistance = attrib(default=None)
    audioChannelFormats = attrib(default=Factory(list), repr=False)
    audioPackFormats = attrib(default=Factory(list), repr=False)
    importance = attrib(default=None, validator=optional(instance_of(int)))

    # attributes for type==Matrix
    # encode and decode pack references are a single binary many-many
    # relationship; only store one side
    encodePackFormats = attrib(default=Factory(list))
    inputPackFormat = attrib(default=None)
    outputPackFormat = attrib(default=None)

    # attributes for type==HOA
    normalization = attrib(default=None, validator=optional(instance_of(str)))
    nfcRefDist = attrib(default=None, validator=optional(instance_of(float)))
    screenRef = attrib(default=None, validator=optional(instance_of(bool)))

    audioChannelFormatIDRef = attrib(default=None)
    audioPackFormatIDRef = attrib(default=None)
    encodePackFormatIDRef = attrib(default=None)
    decodePackFormatIDRef = attrib(default=None)
    inputPackFormatIDRef = attrib(default=None)
    outputPackFormatIDRef = attrib(default=None)

    def lazy_lookup_references(self, adm):
        if self.audioChannelFormatIDRef is not None:
            self.audioChannelFormats = _lookup_elements(
                adm, self.audioChannelFormatIDRef)
            self.audioChannelFormatIDRef = None
        if self.audioPackFormatIDRef is not None:
            self.audioPackFormats = _lookup_elements(adm,
                                                     self.audioPackFormatIDRef)
            self.audioPackFormatIDRef = None

        def add_encodePackFormat(decode_pack, new_encode_pack):
            if not any(encode_pack is new_encode_pack
                       for encode_pack in decode_pack.encodePackFormats):
                decode_pack.encodePackFormats.append(new_encode_pack)

        if self.decodePackFormatIDRef is not None:
            for decode_pack in _lookup_elements(adm,
                                                self.decodePackFormatIDRef):
                add_encodePackFormat(decode_pack, self)
            self.decodePackFormatIDRef = None

        if self.encodePackFormatIDRef is not None:
            for encode_pack in _lookup_elements(adm,
                                                self.encodePackFormatIDRef):
                add_encodePackFormat(self, encode_pack)
            self.encodePackFormatIDRef = None

        if self.inputPackFormatIDRef is not None:
            self.inputPackFormat = adm.lookup_element(
                self.inputPackFormatIDRef)
            self.inputPackFormatIDRef = None

        if self.outputPackFormatIDRef is not None:
            self.outputPackFormat = adm.lookup_element(
                self.outputPackFormatIDRef)
            self.outputPackFormatIDRef = None
Esempio n. 26
0
def optional_instance_of(cls):
    # type: (Any) -> _ValidatorType[Optional[_T]]
    return validators.optional(validators.instance_of(cls))
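
A hypothetical usage of the helper above:

import attr

@attr.s
class Job:
    timeout = attr.ib(default=None, validator=optional_instance_of(float))

Job()               # accepted: None passes the optional() wrapper
Job(timeout=1.5)    # accepted: a float satisfies instance_of(float)
# Job(timeout="x")  # would raise TypeError from the wrapped instance_of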
Esempio n. 27
0
class Frequency(object):
    lowPass = attrib(default=None, validator=optional(instance_of(float)))
    highPass = attrib(default=None, validator=optional(instance_of(float)))
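
Because both validators above use instance_of(float), an int is rejected even though it is numerically valid; a brief illustration (assuming the usual attrs decorator on the class):

Frequency()                  # accepted: both fields default to None
Frequency(lowPass=120.0)     # accepted: float passes optional(instance_of(float))
# Frequency(lowPass=120)     # would raise TypeError: an int is not a float here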
Esempio n. 28
0
class ConfigSchema(Schema):

    NAME: ClassVar[str] = "config"

    # model config

    entities: Dict[str, EntitySchema] = attr.ib(
        validator=non_empty,
        metadata={
            'help':
            "The entity types. The ID with which they are "
            "referenced by the relation types is the key they "
            "have in this dict."
        },
    )
    relations: List[RelationSchema] = attr.ib(
        validator=non_empty,
        metadata={
            'help':
            "The relation types. The ID with which they will be "
            "referenced in the edge lists is their index in this "
            "list."
        },
    )
    dimension: int = attr.ib(
        validator=positive,
        metadata={
            'help': "The dimension of the real space the embedding live "
            "in."
        },
    )
    init_scale: float = attr.ib(
        default=1e-3,
        validator=positive,
        metadata={
            'help':
            "If no initial embeddings are provided, they are "
            "generated by sampling each dimension from a "
            "centered normal distribution having this standard "
            "deviation. (For performance reasons, sampling isn't "
            "fully independent.)"
        },
    )
    max_norm: Optional[float] = attr.ib(
        default=None,
        validator=optional(positive),
        metadata={
            'help':
            "If set, rescale the embeddings if their norm "
            "exceeds this value."
        },
    )
    global_emb: bool = attr.ib(
        default=True,
        metadata={
            'help':
            "If enabled, add to each embedding a vector that is "
            "common to all the entities of a certain type. This "
            "vector is learned during training."
        },
    )
    comparator: Comparator = attr.ib(
        default=Comparator.COS,
        metadata={
            'help':
            "How the embeddings of the two sides of an edge "
            "(after having already undergone some processing) "
            "are compared to each other to produce a score."
        },
    )
    bias: bool = attr.ib(
        default=False,
        metadata={
            'help':
            "If enabled, withhold the first dimension of the "
            "embeddings from the comparator and instead use it "
            "as a bias, adding back to the score. Makes sense "
            "for logistic and softmax loss functions."
        },
    )
    loss_fn: LossFunction = attr.ib(
        default=LossFunction.RANKING,
        metadata={
            'help':
            "How the scores of positive edges and their "
            "corresponding negatives are evaluated."
        },
    )
    margin: float = attr.ib(
        default=0.1,
        metadata={
            'help':
            "When using ranking loss, this value controls the "
            "minimum separation between positive and negative "
            "scores, below which a (linear) loss is incured."
        },
    )

    # data config

    entity_path: str = attr.ib(metadata={
        'help':
        "The path of the directory containing entity count "
        "files."
    }, )
    edge_paths: List[str] = attr.ib(metadata={
        'help':
        "A list of paths to directories containing "
        "(partitioned) edgelists. Typically a single path is "
        "provided."
    }, )
    checkpoint_path: str = attr.ib(metadata={
        'help':
        "The path to the directory where checkpoints (and "
        "thus the output) will be written to. If checkpoints "
        "are found in it, training will resume from them."
    }, )
    init_path: Optional[str] = attr.ib(
        default=None,
        metadata={
            'help':
            "If set, it must be a path to a directory that "
            "contains initial values for the embeddings of all "
            "the entities of some types."
        },
    )

    # training config

    num_epochs: int = attr.ib(
        default=1,
        validator=non_negative,
        metadata={
            'help':
            "The number of times the training loop iterates over "
            "all the edges."
        },
    )
    num_edge_chunks: int = attr.ib(
        default=1,
        validator=positive,
        metadata={
            'help':
            "The number of equally-sized parts each bucket will "
            "be split into. Training will first proceed over all "
            "the first chunks of all buckets, then over all the "
            "second chunks, and so on. A higher value allows "
            "better mixing of partitions, at the cost of more "
            "time spent on I/O."
        },
    )
    bucket_order: BucketOrder = attr.ib(
        default=BucketOrder.INSIDE_OUT,
        metadata={'help': "The order in which to iterate over the buckets."},
    )
    workers: Optional[int] = attr.ib(
        default=None,
        validator=optional(positive),
        metadata={
            'help':
            "The number of worker processes for \"Hogwild!\" "
            "training. If not given, set to CPU count."
        },
    )
    batch_size: int = attr.ib(
        default=1000,
        validator=positive,
        metadata={'help': "The number of edges per batch."},
    )
    num_batch_negs: int = attr.ib(
        default=50,
        validator=non_negative,
        metadata={
            'help':
            "The number of negatives sampled from the batch, per "
            "positive edge."
        },
    )
    num_uniform_negs: int = attr.ib(
        default=50,
        validator=non_negative,
        metadata={
            'help':
            "The number of negatives uniformly sampled from the "
            "currently active partition, per positive edge."
        },
    )
    lr: float = attr.ib(
        default=1e-2,
        validator=non_negative,
        metadata={'help': "The learning rate for the optimizer."},
    )
    relation_lr: Optional[float] = attr.ib(
        default=None,
        validator=optional(non_negative),
        metadata={
            'help':
            "If set, the learning rate for the optimizer"
            "for relations. Otherwise, `lr' is used."
        },
    )
    eval_fraction: float = attr.ib(
        default=0.05,
        validator=non_negative,
        metadata={
            'help':
            "The fraction of edges withheld from training and "
            "used to track evaluation metrics during training."
        },
    )
    eval_num_batch_negs: int = attr.ib(
        default=1000,
        validator=non_negative,
        metadata={
            'help':
            "The value that overrides the number of negatives "
            "per positive edge sampled from the batch during the "
            "evaluation steps that occur before and after each "
            "training step."
        },
    )
    eval_num_uniform_negs: int = attr.ib(
        default=1000,
        validator=non_negative,
        metadata={
            'help':
            "The value that overrides the number of "
            "uniformly-sampled negatives per positive edge "
            "during the evaluation steps that occur before and "
            "after each training step."
        },
    )

    # expert options

    background_io: bool = attr.ib(
        default=False,
        metadata={'help': "Whether to do load/save in a background process."},
    )
    verbose: int = attr.ib(
        default=0,
        validator=non_negative,
        metadata={'help': "The verbosity level of logging, currently 0 or 1."},
    )
    hogwild_delay: float = attr.ib(
        default=2,
        validator=non_negative,
        metadata={
            'help':
            "The number of seconds by which to delay the start "
            "of all \"Hogwild!\" processes except the first one, "
            "on the first epoch."
        },
    )
    dynamic_relations: bool = attr.ib(
        default=False,
        metadata={
            'help':
            "If enabled, activates the dynamic relation mode, in "
            "which case, there must be a single relation type in "
            "the config (whose parameters will apply to all "
            "dynamic relations types) and there must be a file "
            "called dynamic_rel_count.txt in the entity path that "
            "contains the number of dynamic relations. In this "
            "mode, batches will contain edges of multiple "
            "relation types and negatives will be sampled "
            "differently."
        },
    )

    # distributed training config options

    num_machines: int = attr.ib(
        default=1,
        validator=positive,
        metadata={'help': "The number of machines for distributed training."},
    )
    num_partition_servers: int = attr.ib(
        default=-1,
        metadata={
            'help':
            "If -1, use trainer as partition servers. If 0, "
            "don't use partition servers (instead, swap "
            "partitions through disk). If >1, then that number "
            "of partition servers must be started manually."
        },
    )
    distributed_init_method: Optional[str] = attr.ib(
        default=None,
        metadata={
            'help':
            "A URI defining how to synchronize all the workers "
            "of a distributed run. Must start with a scheme "
            "(e.g., file:// or tcp://) supported by PyTorch."
        })
    distributed_tree_init_order: bool = attr.ib(
        default=True,
        metadata={
            'help':
            "If enabled, then distributed training can occur on "
            "a bucket only if at least one of its partitions was "
            "already trained on before in the same round (or if "
            "one of its partitions is 0, for bootstrapping)."
        },
    )

    # Additional global validation.

    def __attrs_post_init__(self):
        for rel_id, rel_config in enumerate(self.relations):
            if rel_config.lhs not in self.entities:
                raise ValueError("Relation type %s (#%d) has an unknown "
                                 "left-hand side entity type %s" %
                                 (rel_config.name, rel_id, rel_config.lhs))
            if rel_config.rhs not in self.entities:
                raise ValueError("Relation type %s (#%d) has an unknown "
                                 "right-hand side entity type %s" %
                                 (rel_config.name, rel_id, rel_config.rhs))
        if self.dynamic_relations:
            if len(self.relations) != 1:
                raise ValueError("When dynamic relations are in use only one "
                                 "relation type must be defined.")
        # TODO Check that all partitioned entity types have the same number of partitions
        # TODO Check that the batch size is a multiple of the batch negative number
        if self.loss_fn is LossFunction.LOGISTIC and self.comparator is Comparator.COS:
            print(
                "WARNING: You have logistic loss and cosine distance. Are you sure?"
            )
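
The non_empty, positive, and non_negative validators referenced throughout the schema above are not defined in this snippet; plausible minimal versions (assumptions for illustration, not necessarily the project's actual definitions) could look like:

def non_empty(instance, attribute, value):
    if len(value) == 0:
        raise ValueError("%s must be non-empty" % attribute.name)

def positive(instance, attribute, value):
    if value <= 0:
        raise ValueError("%s must be positive" % attribute.name)

def non_negative(instance, attribute, value):
    if value < 0:
        raise ValueError("%s must be non-negative" % attribute.name)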
Esempio n. 29
0
class AudioObjectInteraction(ADMElement):
    onOffInteract = attrib(default=None, validator=optional(instance_of(bool)))
    gainInteract = attrib(default=None, validator=optional(instance_of(bool)))
    positionInteract = attrib(default=None,
                              validator=optional(instance_of(bool)))
Esempio n. 30
0
class PipenvMarkers(BaseRequirement):
    """System-level requirements - see PEP508 for more detail"""
    os_name = attrib(default=None,
                     validator=validators.optional(_validate_markers))
    sys_platform = attrib(default=None,
                          validator=validators.optional(_validate_markers))
    platform_machine = attrib(default=None,
                              validator=validators.optional(_validate_markers))
    platform_python_implementation = attrib(
        default=None, validator=validators.optional(_validate_markers))
    platform_release = attrib(default=None,
                              validator=validators.optional(_validate_markers))
    platform_system = attrib(default=None,
                             validator=validators.optional(_validate_markers))
    platform_version = attrib(default=None,
                              validator=validators.optional(_validate_markers))
    python_version = attrib(default=None,
                            validator=validators.optional(_validate_markers))
    python_full_version = attrib(
        default=None, validator=validators.optional(_validate_markers))
    implementation_name = attrib(
        default=None, validator=validators.optional(_validate_markers))
    implementation_version = attrib(
        default=None, validator=validators.optional(_validate_markers))

    @property
    def line_part(self):
        return " and ".join([
            "{0} {1}".format(k, v)
            for k, v in attr.asdict(self, filter=_filter_none).items()
        ])

    @property
    def pipfile_part(self):
        return {"markers": self.as_line}

    @classmethod
    def make_marker(cls, marker_string):
        marker = Marker(marker_string)
        marker_dict = {}
        for m in marker._markers:
            if isinstance(m, six.string_types):
                continue
            var, op, val = m
            if var.value in cls.attr_fields():
                marker_dict[var.value] = '{0} "{1}"'.format(op, val)
        return marker_dict

    @classmethod
    def from_line(cls, line):
        if ";" in line:
            line = line.rsplit(";", 1)[1].strip()
        marker_dict = cls.make_marker(line)
        return cls(**marker_dict)

    @classmethod
    def from_pipfile(cls, name, pipfile):
        found_keys = [k for k in pipfile.keys() if k in cls.attr_fields()]
        marker_strings = ["{0} {1}".format(k, pipfile[k]) for k in found_keys]
        if pipfile.get("markers"):
            marker_strings.append(pipfile.get("markers"))
        markers = {}
        for marker in marker_strings:
            marker_dict = cls.make_marker(marker)
            if marker_dict:
                markers.update(marker_dict)
        return cls(**markers)
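
A hypothetical usage of PipenvMarkers.from_line (the requirement line is illustrative):

markers = PipenvMarkers.from_line("requests; python_version >= '3.6'")
print(markers.line_part)   # expected to print something like: python_version >= "3.6"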
Esempio n. 31
0
class AwsLambda(State):
    """Invoke a Lambda function.

    `See service docs for more details.
    <https://docs.aws.amazon.com/lambda/latest/dg/API_Invoke.html>`_

    :param FunctionName: AWS Lambda Function to call
    :param Payload: Data to provide to the Lambda Function as input
    :type Payload: :class:`Parameters`, :class:`JsonPath`, :class:`AWSHelperFn`, dict, str, or :class:`Enum`
    :param ClientContext:
       Up to 3583 bytes of base64-encoded data about the invoking client to pass to the function in the context object
    :type ClientContext: :class:`JsonPath`, :class:`AWSHelperFn`, str, or :class:`Enum`
    :param InvocationType: Determines how the Lambda Function is invoked
    :type InvocationType: :class:`JsonPath`, :class:`AWSHelperFn`, str, or :class:`Enum`
    :param Qualifier: Version or alias of the Lambda Function to invoke
    :type Qualifier: :class:`JsonPath`, :class:`AWSHelperFn`, str, or :class:`Enum`
    """

    _required_fields = (RequiredValue(
        "FunctionName", "AWS Lambda Task requires a function name."), )
    _resource_name = ServiceArn.AWSLAMBDA

    # TODO: FunctionName MUST have length 1 <= n <= 170
    #  Pattern: (arn:(aws[a-zA-Z-]*)?:lambda:)?([a-z]{2}(-gov)?-[a-z]+-\d{1}:)?(\d{12}:)?(function:)?([a-zA-Z0-9-_\.]+)(:(\$LATEST|[a-zA-Z0-9-_]+))?
    FunctionName = RHODES_ATTRIB(
        validator=optional(instance_of(AWS_LAMBDA_FUNCTION_TYPES)))
    Payload = RHODES_ATTRIB(validator=optional(
        instance_of(SERVICE_INTEGRATION_SIMPLE_VALUE_TYPES +
                    SERVICE_INTEGRATION_COMPLEX_VALUE_TYPES)))
    ClientContext = RHODES_ATTRIB(validator=optional(
        instance_of(SERVICE_INTEGRATION_SIMPLE_VALUE_TYPES)))
    # TODO: Step Functions seems to accept InvocationType in a state machine definition,
    #  but I'm still not convinced it's actually valid at execution time...
    InvocationType = RHODES_ATTRIB(
        default=None,
        validator=optional(
            instance_of(SERVICE_INTEGRATION_SIMPLE_VALUE_TYPES)))
    Qualifier = RHODES_ATTRIB(validator=optional(
        instance_of(SERVICE_INTEGRATION_SIMPLE_VALUE_TYPES)))

    @InvocationType.validator
    def _validator_invocationtype(self, attribute, value):
        if not isinstance(value, str):
            return

        if value not in AWS_LAMBDA_INVOCATION_TYPES:
            raise ValueError(
                f"'InvocationType' value must be in {AWS_LAMBDA_INVOCATION_TYPES}."
            )

    @ClientContext.validator
    def _validate_clientcontext(self, attribute, value):
        # pylint: disable=no-self-use,unused-argument
        if not isinstance(value, str):
            return

        max_length = 3583
        actual = len(value)
        if actual > max_length:
            raise ValueError(
                f"'ClientContext' length {actual} is larger than maximum {max_length}."
            )

    @Qualifier.validator
    def _validate_qualifier(self, attribute, value):
        # pylint: disable=no-self-use,unused-argument
        if not isinstance(value, str):
            return

        min_length = 1
        max_length = 128
        actual = len(value)
        if not min_length <= actual <= max_length:
            raise ValueError(
                f"'ClientContext' length {actual} is outside allowed range from {min_length} to {max_length}."
            )
Esempio n. 32
0
class VCSRequirement(FileRequirement):
    editable = attrib(default=None)
    uri = attrib(default=None)
    path = attrib(default=None, validator=validators.optional(_validate_path))
    vcs = attrib(validator=validators.optional(_validate_vcs), default=None)
    # : vcs reference name (branch / commit / tag)
    ref = attrib(default=None)
    subdirectory = attrib(default=None)
    name = attrib()
    link = attrib()
    req = attrib()
    _INCLUDE_FIELDS = ("editable", "uri", "path", "vcs", "ref", "subdirectory",
                       "name", "link", "req")

    @link.default
    def get_link(self):
        return build_vcs_link(
            self.vcs,
            _clean_git_uri(self.uri),
            name=self.name,
            ref=self.ref,
            subdirectory=self.subdirectory,
        )

    @name.default
    def get_name(self):
        return self.link.egg_fragment or self.req.name if self.req else ""

    @property
    def vcs_uri(self):
        uri = self.uri
        if not any(uri.startswith("{0}+".format(vcs)) for vcs in VCS_LIST):
            uri = "{0}+{1}".format(self.vcs, uri)
        return uri

    @req.default
    def get_requirement(self):
        prefix = "-e " if self.editable else ""
        line = "{0}{1}".format(prefix, self.link.url)
        req = first(requirements.parse(line))
        if self.path and self.link and self.link.scheme.startswith("file"):
            req.local_file = True
            req.path = self.path
        if self.editable:
            req.editable = True
        req.link = self.link
        if (self.uri != self.link.url and "git+ssh://" in self.link.url
                and "git+git@" in self.uri):
            req.line = _strip_ssh_from_git_uri(req.line)
            req.uri = _strip_ssh_from_git_uri(req.uri)
        if not req.name:
            raise ValueError(
                "pipenv requires an #egg fragment for version controlled "
                "dependencies. Please install remote dependency "
                "in the form {0}#egg=<package-name>.".format(req.uri))
        if self.vcs and not req.vcs:
            req.vcs = self.vcs
        if self.ref and not req.revision:
            req.revision = self.ref
        return req

    @classmethod
    def from_pipfile(cls, name, pipfile):
        creation_args = {}
        pipfile_keys = [
            k for k in ("ref", "vcs", "subdirectory", "path", "editable",
                        "file", "uri") + VCS_LIST if k in pipfile
        ]
        for key in pipfile_keys:
            if key in VCS_LIST:
                creation_args["vcs"] = key
                composed_uri = _clean_git_uri("{0}+{1}".format(
                    key, pipfile.get(key))).lstrip("{0}+".format(key))
                is_url = is_valid_url(
                    pipfile.get(key)) or is_valid_url(composed_uri)
                target_key = "uri" if is_url else "path"
                creation_args[target_key] = pipfile.get(key)
            else:
                creation_args[key] = pipfile.get(key)
        creation_args["name"] = name
        return cls(**creation_args)

    @classmethod
    def from_line(cls, line, editable=None):
        path = None
        if line.startswith("-e "):
            editable = True
            line = line.split(" ", 1)[1]
        vcs_line = _clean_git_uri(line)
        vcs_method, vcs_location = _split_vcs_method(vcs_line)
        if not is_valid_url(vcs_location) and os.path.exists(vcs_location):
            path = get_converted_relative_path(vcs_location)
            vcs_location = path_to_url(os.path.abspath(vcs_location))
        link = Link(vcs_line)
        name = link.egg_fragment
        uri = link.url_without_fragment
        if "git+git@" in line:
            uri = _strip_ssh_from_git_uri(uri)
        subdirectory = link.subdirectory_fragment
        ref = None
        if "@" in link.show_url:
            uri, ref = uri.rsplit("@", 1)
        return cls(
            name=name,
            ref=ref,
            vcs=vcs_method,
            subdirectory=subdirectory,
            link=link,
            path=path,
            editable=editable,
            uri=uri,
        )

    @property
    def line_part(self):
        """requirements.txt compatible line part sans-extras"""
        if self.req:
            return self.req.line
        base = "{0}".format(self.link)
        if self.editable:
            base = "-e {0}".format(base)
        return base

    @staticmethod
    def _choose_vcs_source(pipfile):
        src_keys = [k for k in pipfile.keys() if k in ["path", "uri", "file"]]
        if src_keys:
            chosen_key = first(src_keys)
            vcs_type = pipfile.pop("vcs")
            _, pipfile_url = _split_vcs_method(pipfile.get(chosen_key))
            pipfile[vcs_type] = pipfile_url
            for removed in src_keys:
                _ = pipfile.pop(removed)
        return pipfile

    @property
    def pipfile_part(self):
        pipfile_dict = attr.asdict(self, filter=_filter_none).copy()
        if "vcs" in pipfile_dict:
            pipfile_dict = self._choose_vcs_source(pipfile_dict)
        name = pipfile_dict.pop("name")
        return {name: pipfile_dict}
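
A hypothetical usage of VCSRequirement.from_line (the repository URL is illustrative only):

req = VCSRequirement.from_line(
    "-e git+https://github.com/owner/project.git@v1.0#egg=project")
print(req.vcs, req.ref, req.editable)   # expected: git v1.0 True
print(req.line_part)                    # the requirements.txt-compatible line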
Esempio n. 33
0
class SentrySettings(Section):
    dsn: Optional[str] = attrib(validator=optional(instance_of(str)),
                                default=None)
Esempio n. 34
0
import os.path
import yaml

from attr import attributes, attr, validators, asdict


valid_str = validators.instance_of(str)

optional_str_attr = attr(
    validator=validators.optional(valid_str),
    default='',
)


@attributes
class Config:
    username = optional_str_attr
    password = optional_str_attr


def get_config(path):

    if not os.path.exists(path):
        return Config()

    with open(path) as f:
        config = yaml.safe_load(f)

    return Config(
        username=config['username'],
        password=config['password'],
    )
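
A hypothetical usage of get_config above (the path is illustrative; a missing file simply yields the defaults declared on Config):

cfg = get_config("config.yml")
print(cfg.username, cfg.password)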
Esempio n. 35
0
class UrlItem:
    """A URL in a toctree."""

    # regex should match sphinx.util.url_re
    url: str = attr.ib(validator=[instance_of(str), matches_re(r".+://.*")])
    title: Optional[str] = attr.ib(None, validator=optional(instance_of(str)))
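
A hypothetical usage of the class above (assuming the usual attrs decorator on it):

UrlItem("https://example.com", title="Example")   # accepted: matches r".+://.*"
# UrlItem("not-a-url")   # would raise ValueError from matches_re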
Esempio n. 36
0
 def test_success_with_none(self, validator):
     """
     Nothing happens if None.
     """
     v = optional(validator)
     v(None, simple_attr("test"), None)
Esempio n. 37
0
class DataObj:
    """
    Class that holds a data object (either a note or a bookmark).

    Attributes:

    [Required to pass when creating a new object]

    - **type** -> "note" or "bookmark"

    **Note**:

    - title

    **Bookmark**:

    - url

    [Optional attrs that if passed, will be set by the class]

    - tags
    - content
    - path

    [Handled by the code]

    - id
    - date

    For bookmarks,
    Run `process_bookmark_url()` once you've created it.

    For both types, run `insert()` if you want to create a new file in
    the db with their contents.
    """

    __searchable__ = ["title", "content", "tags"]

    id: Optional[int] = attrib(validator=optional(instance_of(int)),
                               default=None)
    type: str = attrib(validator=instance_of(str))
    title: str = attrib(validator=instance_of(str), default="")
    content: str = attrib(validator=instance_of(str), default="")
    tags: List[str] = attrib(validator=instance_of(list), factory=list)
    url: Optional[str] = attrib(validator=optional(instance_of(str)),
                                default=None)
    date: Optional[datetime] = attrib(
        validator=optional(instance_of(datetime)),
        default=None,
    )
    path: str = attrib(validator=instance_of(str), default="")
    fullpath: Optional[str] = attrib(validator=optional(instance_of(str)),
                                     default=None)

    def process_bookmark_url(self):
        """Process url to get content for bookmark"""
        if self.type not in (
                "bookmark", "pocket_bookmark") or not validators.url(self.url):
            return None

        try:
            url_request = requests.get(self.url)
        except Exception:
            flash(f"Could not retrieve {self.url}\n", "error")
            self.wipe()
            return

        try:
            parsed_html = BeautifulSoup(url_request.text,
                                        features="html.parser")
        except Exception:
            flash(f"Could not parse {self.url}\n", "error")
            self.wipe()
            return

        try:
            self.content = self.extract_content(parsed_html)
        except Exception:
            flash(f"Could not extract content from {self.url}\n", "error")
            return

        parsed_title = parsed_html.title
        self.title = parsed_title.string if parsed_title is not None else self.url

    def wipe(self):
        """Resets and invalidates dataobj"""
        self.title = ""
        self.content = ""

    def extract_content(self, beautsoup):
        """converts html bookmark url to optimized markdown"""

        stripped_tags = ["footer", "nav"]
        url = self.url.rstrip("/")

        for tag in stripped_tags:
            if getattr(beautsoup, tag):
                getattr(beautsoup, tag).extract()
        resources = beautsoup.find_all(["a", "img"])
        for tag in resources:
            if tag.name == "a":
                if tag.has_attr("href") and (tag["href"].startswith("/")):
                    tag["href"] = urljoin(url, tag["href"])

                # check it's a normal link and not some sort of image
                # string returns the text content of the tag
                if not tag.string:
                    # delete tag
                    tag.decompose()

            elif (tag.name == "img" and tag.has_attr("src") and
                  (tag["src"].startswith("/") or tag["src"].startswith("./"))):

                tag["src"] = urljoin(url, tag["src"])

        res = html2text(str(beautsoup), bodywidth=0)
        return res

    def validate(self):
        """Verifies that the content matches required validation constraints"""
        valid_url = self.type not in ("bookmark", "pocket_bookmark") or (
            isinstance(self.url, str) and validators.url(self.url))

        valid_title = isinstance(self.title, str) and self.title != ""
        valid_content = self.type not in ("bookmark",
                                          "pocket_bookmark") or isinstance(
                                              self.content, str)
        return valid_url and valid_title and valid_content

    def insert(self):
        """Creates a new file with the object's attributes"""
        if self.validate():

            helpers.set_max_id(helpers.get_max_id() + 1)
            self.id = helpers.get_max_id()
            self.date = datetime.now()

            hooks = helpers.load_hooks()

            hooks.before_dataobj_create(self)
            data = {
                "type": self.type,
                "title": str(self.title),
                "date": self.date.strftime("%x").replace("/", "-"),
                "tags": self.tags,
                "id": self.id,
                "path": self.path,
            }
            if self.type == "bookmark" or self.type == "pocket_bookmark":
                data["url"] = self.url

            # convert to markdown file
            dataobj = frontmatter.Post(self.content)
            dataobj.metadata = data
            self.fullpath = str(
                create(
                    frontmatter.dumps(dataobj),
                    f"{self.id}-{dataobj['title']}",
                    path=self.path,
                ))

            hooks.on_dataobj_create(self)
            self.index()
            return self.id
        return False

    def index(self):
        return add_to_index(self)

    @classmethod
    def from_md(cls, md_content: str):
        """
        Class method to generate new dataobj from a well formatted markdown string

        Call like this:

        ```python
        Dataobj.from_md(content)

        ```
        """
        data = frontmatter.loads(md_content)
        dataobj = {}
        dataobj["content"] = data.content
        for pair in ["tags", "id", "title", "path"]:
            try:
                dataobj[pair] = data[pair]
            except KeyError:
                # files sometimes get moved temporarily by applications while you edit
                # this can create bugs where the data is not loaded correctly
                # this handles that scenario as validation will simply fail and the event will
                # be ignored
                break

        dataobj["type"] = "processed-dataobj"
        return cls(**dataobj)
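
A hypothetical usage of from_md above (the markdown content is illustrative):

md = """---
title: My note
id: 1
tags: [example]
path: ""
---
Body text of the note."""

obj = DataObj.from_md(md)
print(obj.type, obj.title)   # -> processed-dataobj My note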
Esempio n. 38
0
 def test_success_with_type(self):
     """
     Nothing happens if types match.
     """
     v = optional(instance_of(int))
     v(None, simple_attr("test"), 42)