class QCircuitMachine(RuleBasedStateMachine):
    """Hypothesis rule-based state machine that grows, transpiles and
    simulates random QuantumCircuits.

    The circuit under test is capped at ``max_qubits`` qubits (half of the
    qubit count reported by the Aer simulator), ``max_qregs`` quantum
    registers and ``max_cregs`` classical registers. Rules append gates drawn
    from ``BASE_INSTRUCTIONS`` (and ``variadic_gates`` when non-empty) with
    random qargs, cargs and parameters, and may turn the most recent gate into
    a conditional one. Interleaved with construction, ``equivalent_transpile``
    transpiles the circuit with a randomly drawn configuration and checks
    that simulated counts before and after transpilation agree.
    """

    qubits = Bundle("qubits")
    clbits = Bundle("clbits")

    backend = Aer.get_backend("aer_simulator")
    # Circuit width is capped at half of the simulator's reported qubit count.
    max_qubits = backend.configuration().n_qubits // 2

    # Keep the register count small so circuits stay interesting.
    max_qregs = 3
    max_cregs = 3

    def __init__(self):
        super().__init__()
        self.qc = QuantumCircuit()
        # Variadic-gate rule is only enabled when some variadic gates exist.
        self.enable_variadic = bool(variadic_gates)

    @precondition(lambda self: len(self.qc.qubits) < self.max_qubits)
    @precondition(lambda self: len(self.qc.qregs) < self.max_qregs)
    @rule(target=qubits, n=st.integers(min_value=1, max_value=max_qubits))
    def add_qreg(self, n):
        """Add a quantum register of ``n`` qubits, shrunk so the circuit never
        exceeds ``max_qubits``; every new qubit joins the ``qubits`` bundle."""
        remaining = self.max_qubits - len(self.qc.qubits)
        register = QuantumRegister(min(n, remaining))
        self.qc.add_register(register)
        return multiple(*register)

    @precondition(lambda self: len(self.qc.cregs) < self.max_cregs)
    @rule(target=clbits, n=st.integers(1, 5))
    def add_creg(self, n):
        """Add a classical register of ``n`` bits; the bits join ``clbits``."""
        register = ClassicalRegister(n)
        self.qc.add_register(register)
        return multiple(*register)

    # Gates of various shapes

    @precondition(lambda self: self.qc.num_qubits > 0 and self.qc.num_clbits > 0)
    @rule(n_arguments=st.sampled_from(sorted(BASE_INSTRUCTIONS.keys())), data=st.data())
    def add_gate(self, n_arguments, data):
        """Append one fixed-shape instruction.

        ``n_arguments`` is a ``(n_qubits, n_clbits, n_params)`` key of
        ``BASE_INSTRUCTIONS``; the instruction class, its distinct qubits, its
        distinct clbits and its float parameters are drawn in that order.
        """
        n_qubits, n_clbits, n_params = n_arguments
        instruction_cls = data.draw(st.sampled_from(BASE_INSTRUCTIONS[n_arguments]))
        qargs = data.draw(
            st.lists(self.qubits, min_size=n_qubits, max_size=n_qubits, unique=True)
        )
        cargs = data.draw(
            st.lists(self.clbits, min_size=n_clbits, max_size=n_clbits, unique=True)
        )
        angle = st.floats(
            allow_nan=False, allow_infinity=False, min_value=-10 * pi, max_value=10 * pi
        )
        angles = data.draw(st.lists(angle, min_size=n_params, max_size=n_params))
        self.qc.append(instruction_cls(*angles), qargs, cargs)

    @precondition(lambda self: self.enable_variadic)
    @rule(gate=st.sampled_from(variadic_gates), qargs=st.lists(qubits, min_size=1, unique=True))
    def add_variQ_gate(self, gate, qargs):
        """Append a gate sized to the number of distinct qubits drawn."""
        width = len(qargs)
        self.qc.append(gate(width), qargs)

    @precondition(lambda self: len(self.qc.data) > 0)
    @rule(carg=clbits, data=st.data())
    def add_c_if_last_gate(self, carg, data):
        """Make the most recent instruction conditional on the first register
        that contains ``carg``, compared against a random in-range value."""
        register = self.qc.find_bit(carg).registers[0][0]
        value = data.draw(st.integers(min_value=0, max_value=2 ** len(register) - 1))

        last_operation = self.qc.data[-1][0]

        # Only plain Gates are made conditional; other instructions are skipped.
        assume(isinstance(last_operation, Gate))

        last_operation.c_if(register, value)

    # Properties to check

    @invariant()
    def qasm(self):
        """Building QASM must succeed after every step."""
        self.qc.qasm()

    @precondition(lambda self: any(isinstance(d[0], Measure) for d in self.qc.data))
    @rule(conf=transpiler_conf())
    def equivalent_transpile(self, conf):
        """Simulate the circuit, transpile it with the drawn configuration,
        simulate the result, and require the two count distributions to agree
        within 5% of the shot count.
        """
        backend, opt_level, layout_method, routing_method, scheduling_method = conf

        # Skip backends too small to hold the circuit (None means no backend).
        assume(backend is None or backend.configuration().n_qubits >= len(self.qc.qubits))

        print(
            f"Evaluating circuit at level {opt_level} on {backend} "
            f"using layout_method={layout_method} routing_method={routing_method} "
            f"and scheduling_method={scheduling_method}:\n{self.qc.qasm()}"
        )

        shots = 4096

        # The original circuit is run untranspiled, which is why the gate
        # pool is limited to gates Aer supports natively.
        reference_counts = self.backend.run(self.qc, shots=shots).result().get_counts()

        try:
            transpiled = transpile(
                self.qc,
                backend=backend,
                optimization_level=opt_level,
                layout_method=layout_method,
                routing_method=routing_method,
                scheduling_method=scheduling_method,
            )
        except Exception as exc:
            raise RuntimeError(
                f"Exception caught during transpilation of circuit: \n{self.qc.qasm()}"
            ) from exc

        transpiled_counts = self.backend.run(transpiled, shots=shots).result().get_counts()

        mismatch = dicts_almost_equal(reference_counts, transpiled_counts, 0.05 * shots)

        assert mismatch == "", (
            f"Counts not equivalent: {mismatch}\nFailing QASM Input:\n{self.qc.qasm()}"
            f"\n\nFailing QASM Output:\n{transpiled.qasm()}"
        )
Esempio n. 2
0
class verifyingstatemachine(RuleBasedStateMachine):
    """This defines the set of acceptable operations on a Mercurial repository
    using Hypothesis's RuleBasedStateMachine.

    The general concept is that we manage multiple repositories inside a
    repos/ directory in our temporary test location. Some of these are freshly
    inited, some are clones of the others. Our current working directory is
    always inside one of these repositories while the tests are running.

    Hypothesis then performs a series of operations against these repositories,
    including hg commands, generating contents and editing the .hgrc file.
    If these operations fail in unexpected ways or behave differently in
    different configurations of Mercurial, the test will fail and a minimized
    .t test file will be written to the hypothesis-generated directory to
    exhibit that failure.

    Every side-effecting step appends the equivalent shell command to
    ``self.log``; on teardown that log is turned into a ``.t`` file and
    re-run under other configurations.

    Operations are defined as methods with @rule() decorators. See the
    Hypothesis documentation at
    http://hypothesis.readthedocs.org/en/release/stateful.html for more
    details."""

    # A bundle is a reusable collection of previously generated data which may
    # be provided as arguments to future operations.
    repos = Bundle('repos')
    paths = Bundle('paths')
    contents = Bundle('contents')
    branches = Bundle('branches')
    committimes = Bundle('committimes')

    def __init__(self):
        """Start from an empty repos/ directory containing one initialised
        repository, repo1, which becomes the working directory."""
        super(verifyingstatemachine, self).__init__()
        self.repodir = os.path.join(testtmp, "repos")
        if os.path.exists(self.repodir):
            shutil.rmtree(self.repodir)
        os.chdir(testtmp)
        self.log = []  # shell commands replayed later as a .t file
        self.failed = False  # set by execute_step when a step raises
        self.configperrepo = {}  # repo name -> {config key: value}
        self.all_extensions = set()
        # Extensions whose commands were actually used; these cannot be
        # dropped from the generated test in teardown.
        self.non_skippable_extensions = set()

        self.mkdirp("repos")
        self.cd("repos")
        self.mkdirp("repo1")
        self.cd("repo1")
        self.hg("init")

    def teardown(self):
        """On teardown we clean up after ourselves as usual, but we also
        do some additional testing: We generate a .t file based on our test
        run using run-test.py -i to get the correct output.

        We then test it in a number of other configurations, verifying that
        each passes the same test. If a step failed, or any of those runs
        fails, the generated test is saved to ``savefile`` and the run
        failure (if any) is re-raised."""
        super(verifyingstatemachine, self).teardown()
        try:
            shutil.rmtree(self.repodir)
        except OSError:
            pass
        ttest = os.linesep.join("  " + l for l in self.log)
        os.chdir(testtmp)
        path = os.path.join(testtmp, "test-generated.t")
        with open(path, 'w') as o:
            o.write(ttest + os.linesep)
        with open(os.devnull, "w") as devnull:
            rewriter = subprocess.Popen(
                [runtests, "--local", "-i", path],
                stdin=subprocess.PIPE,
                stdout=devnull,
                stderr=devnull,
            )
            # Answer "yes" to accept the recorded output. A b"" literal is a
            # plain str on Python 2 and the bytes a PIPE needs on Python 3.
            rewriter.communicate(b"yes")
            with open(path, 'r') as i:
                ttest = i.read()

        # Keep the failure in a variable that outlives the except clause:
        # Python 3 unbinds the "as" target when the handler finishes.
        error = None
        if not self.failed:
            try:
                output = subprocess.check_output(
                    [runtests, path, "--local", "--pure"],
                    stderr=subprocess.STDOUT)
                assert b"Ran 1 test" in output, output
                # Re-run once per unused extension with that extension's
                # --config flag stripped from every hg invocation.
                for ext in (self.all_extensions -
                            self.non_skippable_extensions):
                    tf = os.path.join(testtmp,
                                      "test-generated-no-%s.t" % (ext, ))
                    with open(tf, 'w') as o:
                        for l in ttest.splitlines():
                            if l.startswith("  $ hg"):
                                l = l.replace(
                                    "--config %s=" %
                                    (extensionconfigkey(ext), ), "")
                            o.write(l + os.linesep)
                    with open(tf, 'r') as r:
                        t = r.read()
                        assert ext not in t, t
                    output = subprocess.check_output([
                        runtests,
                        tf,
                        "--local",
                    ],
                                                     stderr=subprocess.STDOUT)
                    assert b"Ran 1 test" in output, output
            except subprocess.CalledProcessError as exc:
                note(exc.output)
                error = exc
        if self.failed or error is not None:
            with open(savefile, "wb") as o:
                o.write(ttest)
        if error is not None:
            raise error

    def execute_step(self, step):
        """Run one step, remembering in ``self.failed`` whether it raised
        anything other than a Hypothesis control exception or Ctrl-C."""
        try:
            return super(verifyingstatemachine, self).execute_step(step)
        except (HypothesisException, KeyboardInterrupt):
            raise
        except Exception:
            self.failed = True
            raise

    # Section: Basic commands.
    def mkdirp(self, path):
        """Create ``path`` and its parents unless it already exists.

        The shell command is logged before ``os.makedirs`` runs, so it is
        logged even when creation raises OSError."""
        if os.path.exists(path):
            return
        self.log.append("$ mkdir -p -- %s" %
                        (pipes.quote(os.path.relpath(path)), ))
        os.makedirs(path)

    def cd(self, path):
        """Change directory and log it; a no-op when it resolves to '.'."""
        path = os.path.relpath(path)
        if path == ".":
            return
        os.chdir(path)
        self.log.append("$ cd -- %s" % (pipes.quote(path), ))

    def hg(self, *args):
        """Run hg with every --config flag recorded for the current repo."""
        extra_flags = []
        for key, value in self.config.items():
            extra_flags.append("--config")
            extra_flags.append("%s=%s" % (key, value))
        self.command("hg", *(tuple(extra_flags) + args))

    def command(self, *args):
        """Log and run a command; raises CalledProcessError on failure."""
        self.log.append("$ " + ' '.join(map(pipes.quote, args)))
        subprocess.check_output(args, stderr=subprocess.STDOUT)

    # Section: Set up basic data
    # This section has no side effects but generates data that we will want
    # to use later.
    @rule(target=paths,
          source=st.lists(files, min_size=1).map(lambda l: os.path.join(*l)))
    def genpath(self, source):
        return source

    @rule(target=committimes,
          when=datetimes(min_year=1970, max_year=2038) | st.none())
    def gentime(self, when):
        return when

    @rule(target=contents,
          content=st.one_of(st.binary(),
                            st.text().map(lambda x: x.encode('utf-8'))))
    def gencontent(self, content):
        return content

    @rule(
        target=branches,
        name=safetext,
    )
    def genbranch(self, name):
        return name

    @rule(target=paths, source=paths)
    def lowerpath(self, source):
        return source.lower()

    @rule(target=paths, source=paths)
    def upperpath(self, source):
        return source.upper()

    # Section: Basic path operations
    @rule(path=paths, content=contents)
    def writecontent(self, path, content):
        """Write ``content`` to ``path`` and log a shell command that writes
        exactly the same bytes when the .t file is replayed."""
        # NOTE(review): unadded_changes is not read anywhere in this class.
        self.unadded_changes = True
        if os.path.isdir(path):
            return
        parent = os.path.dirname(path)
        if parent:
            try:
                self.mkdirp(parent)
            except OSError:
                # It may be the case that there is a regular file that has
                # previously been created that has the same name as an ancestor
                # of the current path. This will cause mkdirp to fail with this
                # error. We just turn this into a no-op in that case.
                return
        with open(path, 'wb') as o:
            o.write(content)
        # Write raw bytes rather than print(): print would append a newline
        # and make the replayed file differ from the one written above.
        self.log.append(("$ python -c 'import binascii, sys; "
                         "getattr(sys.stdout, \"buffer\", sys.stdout).write("
                         "binascii.unhexlify(\"%s\"))' > %s") % (
                             binascii.hexlify(content),
                             pipes.quote(path),
                         ))

    @rule(path=paths)
    def addpath(self, path):
        if os.path.exists(path):
            self.hg("add", "--", path)

    @rule(path=paths)
    def forgetpath(self, path):
        if os.path.exists(path):
            with acceptableerrors("file is already untracked", ):
                self.hg("forget", "--", path)

    @rule(s=st.none() | st.integers(0, 100))
    def addremove(self, s):
        """Run ``hg addremove``, with ``-s`` similarity when ``s`` is given."""
        args = ["addremove"]
        if s is not None:
            args.extend(["-s", str(s)])
        self.hg(*args)

    @rule(path=paths)
    def removepath(self, path):
        if os.path.exists(path):
            with acceptableerrors(
                    'file is untracked',
                    'file has been marked for add',
                    'file is modified',
            ):
                self.hg("remove", "--", path)

    @rule(
        message=safetext,
        amend=st.booleans(),
        when=committimes,
        addremove=st.booleans(),
        secret=st.booleans(),
        close_branch=st.booleans(),
    )
    def maybecommit(self, message, amend, when, addremove, secret,
                    close_branch):
        """Attempt a commit with a random combination of flags.

        Each flag adds the error messages it may legitimately trigger to the
        accepted set (e.g. dates at the 1970/2038 boundaries)."""
        command = ["commit"]
        errors = ["nothing changed"]
        if amend:
            errors.append("cannot amend public changesets")
            command.append("--amend")
        command.append("-m" + pipes.quote(message))
        if secret:
            command.append("--secret")
        if close_branch:
            command.append("--close-branch")
            errors.append("can only close branch heads")
        if addremove:
            command.append("--addremove")
        if when is not None:
            if when.year == 1970:
                errors.append('negative date value')
            if when.year == 2038:
                errors.append('exceeds 32 bits')
            command.append("--date=%s" %
                           (when.strftime('%Y-%m-%d %H:%M:%S %z'), ))

        with acceptableerrors(*errors):
            self.hg(*command)

    # Section: Repository management
    @property
    def currentrepo(self):
        """Name of the repository directory we are currently inside."""
        return os.path.basename(os.getcwd())

    @property
    def config(self):
        """Mutable --config mapping for the current repository."""
        return self.configperrepo.setdefault(self.currentrepo, {})

    @rule(
        target=repos,
        source=repos,
        name=reponames,
    )
    def clone(self, source, name):
        """Clone ``source`` into sibling ``name`` and cd into it; if ``name``
        already exists nothing happens. Returns ``name`` either way."""
        if not os.path.exists(os.path.join("..", name)):
            self.cd("..")
            self.hg("clone", source, name)
            self.cd(name)
        return name

    @rule(
        target=repos,
        name=reponames,
    )
    def fresh(self, name):
        """Init a new sibling repository ``name`` and cd into it; if it
        already exists nothing happens. Returns ``name`` either way."""
        if not os.path.exists(os.path.join("..", name)):
            self.cd("..")
            self.mkdirp(name)
            self.cd(name)
            self.hg("init")
        return name

    @rule(name=repos)
    def switch(self, name):
        self.cd(os.path.join("..", name))
        assert self.currentrepo == name
        assert os.path.exists(".hg")

    @rule(target=repos)
    def origin(self):
        return "repo1"

    @rule()
    def pull(self, repo=repos):
        # NOTE(review): @rule() declares no arguments, so Hypothesis never
        # passes ``repo``; the Bundle default is unused.
        with acceptableerrors(
                "repository default not found",
                "repository is unrelated",
        ):
            self.hg("pull")

    @rule(newbranch=st.booleans())
    def push(self, newbranch):
        with acceptableerrors(
                "default repository not configured",
                "no changes found",
        ):
            if newbranch:
                self.hg("push", "--new-branch")
            else:
                with acceptableerrors("creates new branches"):
                    self.hg("push")

    # Section: Simple side effect free "check" operations
    @rule()
    def log(self):
        self.hg("log")

    @rule()
    def verify(self):
        self.hg("verify")

    @rule()
    def diff(self):
        self.hg("diff", "--nodates")

    @rule()
    def status(self):
        self.hg("status")

    @rule()
    def export(self):
        self.hg("export")

    # Section: Branch management
    @rule()
    def checkbranch(self):
        self.hg("branch")

    @rule(branch=branches)
    def switchbranch(self, branch):
        with acceptableerrors(
                'cannot use an integer as a name',
                'cannot be used in a name',
                'a branch of the same name already exists',
                'is reserved',
        ):
            self.hg("branch", "--", branch)

    @rule(branch=branches, clean=st.booleans())
    def update(self, branch, clean):
        with acceptableerrors(
                'unknown revision',
                'parse error',
        ):
            if clean:
                self.hg("update", "-C", "--", branch)
            else:
                self.hg("update", "--", branch)

    # Section: Extension management
    def hasextension(self, extension):
        """True if ``extension`` is enabled for the current repository."""
        return extensionconfigkey(extension) in self.config

    def commandused(self, extension):
        """Record that a command from ``extension`` ran, so teardown will not
        try to re-run the generated test without it."""
        assert extension in self.all_extensions
        self.non_skippable_extensions.add(extension)

    @rule(extension=extensions)
    def addextension(self, extension):
        self.all_extensions.add(extension)
        self.config[extensionconfigkey(extension)] = ""

    @rule(extension=extensions)
    def removeextension(self, extension):
        self.config.pop(extensionconfigkey(extension), None)

    # Section: Commands from the shelve extension
    @rule()
    @precondition(lambda self: self.hasextension("shelve"))
    def shelve(self):
        self.commandused("shelve")
        with acceptableerrors("nothing changed"):
            self.hg("shelve")

    @rule()
    @precondition(lambda self: self.hasextension("shelve"))
    def unshelve(self):
        self.commandused("shelve")
        with acceptableerrors("no shelved changes to apply"):
            self.hg("unshelve")
Esempio n. 3
0
class DynamicMachine(RuleBasedStateMachine):
    """Minimal machine whose only rule draws from a bundle named 'hi'.

    No rule in this class targets 'hi', so unless rules are added elsewhere
    the bundle stays empty and ``test_stuff`` cannot run.
    """

    @rule(value=Bundle(u'hi'))
    def test_stuff(self, value):
        # Hypothesis calls rules as bound methods with keyword arguments, so
        # the signature must accept both ``self`` and ``value``.
        pass
Esempio n. 4
0
class InitiatorMixin:
    """Hypothesis rules exercising the payment-initiator side of a node.

    Meant to be mixed into a RuleBasedStateMachine. It reads state the host
    class must provide: ``address_to_channel``, ``expected_expiry``,
    ``block_number``, ``chain_state``, ``address``,
    ``token_network_address``, ``token_network_registry_address``,
    ``channel_opened()`` and ``event()``.
    NOTE(review): presumably supplied by the concrete chain-state machine
    this is combined with — confirm.
    """

    def __init__(self):
        super().__init__()
        # Every secret handed to _new_transfer_description, successful or not.
        self.used_secrets = set()
        # Secrethashes for which a SecretRequest has already been consumed.
        self.processed_secret_requests = set()
        # Secrets of transfers whose initiation produced SendLockedTransfer.
        self.initiated = set()

    def _action_init_initiator(self,
                               transfer: TransferDescriptionWithSecretState):
        """Build an ActionInitInitiator routed over our channel to
        ``transfer.target``.

        Also records ``block_number + 10`` as the expected lock expiry for
        the transfer's secrethash, unless one is already recorded.
        NOTE(review): the +10 is assumed to mirror the expiry the initiator
        state machine assigns — confirm.
        """
        channel = self.address_to_channel[transfer.target]
        if transfer.secrethash not in self.expected_expiry:
            self.expected_expiry[transfer.secrethash] = self.block_number + 10
        return ActionInitInitiator(
            transfer, [factories.make_route_from_channel(channel)])

    def _receive_secret_request(self,
                                transfer: TransferDescriptionWithSecretState):
        """Build a ReceiveSecretRequest as if sent by ``transfer.target``.

        The request's secrethash is recomputed from ``transfer.secret``,
        while the expiration is looked up under ``transfer.secrethash``.
        Callers that swap only ``transfer.secret`` therefore get a request
        whose secrethash differs from the recorded one.
        NOTE(review): relies on ``secrethash`` being stored on the transfer
        rather than derived from ``secret`` — confirm.
        """
        secrethash = sha256(transfer.secret).digest()
        return ReceiveSecretRequest(
            payment_identifier=transfer.payment_identifier,
            amount=transfer.amount,
            expiration=self.expected_expiry[transfer.secrethash],
            secrethash=secrethash,
            sender=transfer.target,
        )

    def _new_transfer_description(self, target, payment_id, amount, secret):
        """Mark ``secret`` as used and describe a transfer from this node."""
        self.used_secrets.add(secret)

        return TransferDescriptionWithSecretState(
            token_network_registry_address=self.token_network_registry_address,
            payment_identifier=payment_id,
            amount=amount,
            token_network_address=self.token_network_address,
            initiator=self.address,
            target=target,
            secret=secret,
        )

    def _invalid_authentic_secret_request(self, previous, action):
        """Feed a SecretRequest from the right sender but with bad details.

        If the secrethash was already processed or the previous lock has been
        removed, no events may be produced. Otherwise the secrethash is
        recorded as processed; events are not checked in that case.
        """
        result = node.state_transition(self.chain_state, action)
        if action.secrethash in self.processed_secret_requests or self._is_removed(
                previous):
            assert not result.events
        else:
            self.processed_secret_requests.add(action.secrethash)

    def _unauthentic_secret_request(self, action):
        """Feed a SecretRequest that must be ignored (no events)."""
        result = node.state_transition(self.chain_state, action)
        assert not result.events

    def _available_amount(self, partner_address):
        """Distributable amount on our channel with ``partner_address``."""
        netting_channel = self.address_to_channel[partner_address]
        return channel.get_distributable(netting_channel.our_state,
                                         netting_channel.partner_state)

    def _assume_channel_opened(self, action):
        """Discard the example unless the channel to the target is open."""
        assume(self.channel_opened(action.transfer.target))

    def _is_removed(self, action):
        """True once the block number has passed the recorded expiry of the
        action's lock by at least DEFAULT_WAIT_BEFORE_LOCK_REMOVAL."""
        expiry = self.expected_expiry[action.transfer.secrethash]
        return self.block_number >= expiry + DEFAULT_WAIT_BEFORE_LOCK_REMOVAL

    # Successful ActionInitInitiator actions, reused by later rules.
    init_initiators = Bundle("init_initiators")

    @rule(
        target=init_initiators,
        partner=partners,
        payment_id=payment_id(),  # pylint: disable=no-value-for-parameter
        amount=integers(min_value=1, max_value=100),
        secret=secret(),  # pylint: disable=no-value-for-parameter
    )
    def valid_init_initiator(self, partner, payment_id, amount, secret):
        """Initiate an affordable payment with a fresh secret; it must emit a
        SendLockedTransfer. The action is stored in ``init_initiators``."""
        assume(amount <= self._available_amount(partner))
        assume(secret not in self.used_secrets)

        transfer = self._new_transfer_description(partner, payment_id, amount,
                                                  secret)
        action = self._action_init_initiator(transfer)
        result = node.state_transition(self.chain_state, action)

        assert event_types_match(result.events, SendLockedTransfer)

        self.initiated.add(transfer.secret)
        self.expected_expiry[transfer.secrethash] = self.block_number + 10

        return action

    @rule(
        partner=partners,
        payment_id=payment_id(),  # pylint: disable=no-value-for-parameter
        excess_amount=integers(min_value=1),
        secret=secret(),  # pylint: disable=no-value-for-parameter
    )
    def exceeded_capacity_init_initiator(self, partner, payment_id,
                                         excess_amount, secret):
        """Initiate a payment larger than the distributable amount; it must
        emit EventPaymentSentFailed."""
        amount = self._available_amount(partner) + excess_amount
        transfer = self._new_transfer_description(partner, payment_id, amount,
                                                  secret)
        action = self._action_init_initiator(transfer)
        result = node.state_transition(self.chain_state, action)
        assert event_types_match(result.events, EventPaymentSentFailed)
        self.event("ActionInitInitiator failed: Amount exceeded")

    @rule(
        previous_action=init_initiators,
        partner=partners,
        payment_id=payment_id(),  # pylint: disable=no-value-for-parameter
        amount=integers(min_value=1),
    )
    def used_secret_init_initiator(self, previous_action, partner, payment_id,
                                   amount):
        """Reusing a still-live secret for a new payment must be a no-op."""
        assume(not self._is_removed(previous_action))
        secret = previous_action.transfer.secret
        transfer = self._new_transfer_description(partner, payment_id, amount,
                                                  secret)
        action = self._action_init_initiator(transfer)
        result = node.state_transition(self.chain_state, action)
        assert not result.events
        self.event("ActionInitInitiator failed: Secret already in use.")

    @rule(previous_action=init_initiators)
    def replay_init_initator(self, previous_action):
        """Replaying an earlier, still-live initiation must be a no-op."""
        assume(not self._is_removed(previous_action))
        result = node.state_transition(self.chain_state, previous_action)
        assert not result.events

    @rule(previous_action=init_initiators)
    def valid_secret_request(self, previous_action):
        """A matching SecretRequest reveals the secret exactly once.

        It is dropped if an earlier request for the same secrethash was
        processed or the lock has been removed.
        """
        action = self._receive_secret_request(previous_action.transfer)
        self._assume_channel_opened(previous_action)
        result = node.state_transition(self.chain_state, action)
        if action.secrethash in self.processed_secret_requests:
            assert not result.events
            self.event(
                "Valid SecretRequest dropped due to previous invalid one.")
        elif self._is_removed(previous_action):
            assert not result.events
            self.event(
                "Otherwise valid SecretRequest dropped due to expired lock.")
        else:
            assert event_types_match(result.events, SendSecretReveal)
            self.event("Valid SecretRequest accepted.")
            self.processed_secret_requests.add(action.secrethash)

    @rule(previous_action=init_initiators, amount=integers())
    def wrong_amount_secret_request(self, previous_action, amount):
        """A SecretRequest with the wrong amount is handled as an invalid
        request from the genuine target."""
        assume(amount != previous_action.transfer.amount)
        self._assume_channel_opened(previous_action)
        transfer = deepcopy(previous_action.transfer)
        transfer.amount = amount
        action = self._receive_secret_request(transfer)
        self._invalid_authentic_secret_request(previous_action, action)

    @rule(
        previous_action=init_initiators,
        secret=secret()  # pylint: disable=no-value-for-parameter
    )
    def secret_request_with_wrong_secrethash(self, previous_action, secret):
        """A SecretRequest for a different secrethash must be ignored."""
        assume(
            sha256_secrethash(secret) != sha256_secrethash(
                previous_action.transfer.secret))
        self._assume_channel_opened(previous_action)
        transfer = deepcopy(previous_action.transfer)
        transfer.secret = secret
        action = self._receive_secret_request(transfer)
        # The helper returns None, so this rule returns None like the others.
        return self._unauthentic_secret_request(action)

    @rule(previous_action=init_initiators, payment_identifier=integers())
    def secret_request_with_wrong_payment_id(self, previous_action,
                                             payment_identifier):
        """A SecretRequest with a different payment id must be ignored."""
        assume(
            payment_identifier != previous_action.transfer.payment_identifier)
        self._assume_channel_opened(previous_action)
        transfer = deepcopy(previous_action.transfer)
        transfer.payment_identifier = payment_identifier
        action = self._receive_secret_request(transfer)
        self._unauthentic_secret_request(action)
Esempio n. 5
0
class PVectorEvolverBuilder(RuleBasedStateMachine):
    """
    Build a list and matching pvector evolver step-by-step.

    In each step in the state machine we do same operation on a list and
    on a pvector evolver, and then when we're done we compare the two.

    Rules that act on a position draw that position through ``st.data()``.
    The older ``st.choices()`` strategy is deprecated and was later removed
    from Hypothesis.
    """
    sequences = Bundle("evolver_sequences")

    @staticmethod
    def _draw_index(data, items):
        """Draw a valid index into the non-empty list ``items``."""
        return data.draw(st.integers(min_value=0, max_value=len(items) - 1))

    @rule(target=sequences, start=PVectorAndLists)
    def initial_value(self, start):
        """
        Some initial values generated by a hypothesis strategy.

        The originals are kept untouched so later steps can check that
        evolving never mutates the source pvector.
        """
        l, pv = start
        return EvolverItem(original_list=l,
                           original_pvector=pv,
                           current_list=l[:],
                           current_evolver=pv.evolver())

    @rule(item=sequences)
    def append(self, item):
        """
        Append an item to the pair of sequences.
        """
        obj = TestObject()
        item.current_list.append(obj)
        item.current_evolver.append(obj)

    @rule(start=sequences, end=sequences)
    def extend(self, start, end):
        """
        Extend a pair of sequences with another pair of sequences.
        """
        # compare() has O(N**2) behavior, so don't want too-large lists:
        assume(len(start.current_list) + len(end.current_list) < 50)
        # The evolver is extended first so that, when start is end, both see
        # the list before it grows.
        start.current_evolver.extend(end.current_list)
        start.current_list.extend(end.current_list)

    @rule(item=sequences, choice=st.data())
    def delete(self, item, choice):
        """
        Remove an item from the sequences.
        """
        assume(item.current_list)
        i = self._draw_index(choice, item.current_list)
        del item.current_list[i]
        del item.current_evolver[i]

    @rule(item=sequences, choice=st.data())
    def setitem(self, item, choice):
        """
        Overwrite an item in the sequence using ``__setitem__``.
        """
        assume(item.current_list)
        i = self._draw_index(choice, item.current_list)
        obj = TestObject()
        item.current_list[i] = obj
        item.current_evolver[i] = obj

    @rule(item=sequences, choice=st.data())
    def set(self, item, choice):
        """
        Overwrite an item in the sequence using ``set``.
        """
        assume(item.current_list)
        i = self._draw_index(choice, item.current_list)
        obj = TestObject()
        item.current_list[i] = obj
        item.current_evolver.set(i, obj)

    @rule(item=sequences)
    def compare(self, item):
        """
        The list and pvector evolver must match.
        """
        # Result intentionally discarded; presumably this only exercises
        # is_dirty() on whatever state the evolver is in.
        item.current_evolver.is_dirty()
        # compare() has O(N**2) behavior, so don't want too-large lists:
        assume(len(item.current_list) < 50)
        # original object unmodified
        assert item.original_list == item.original_pvector
        # evolver matches (index-based on purpose, to exercise __getitem__):
        for i in range(len(item.current_evolver)):
            assert item.current_list[i] == item.current_evolver[i]
        # persistent version matches
        assert_equal(item.current_list, item.current_evolver.persistent())
        # original object still unmodified
        assert item.original_list == item.original_pvector
Esempio n. 6
0
class SyncMachine(RuleBasedStateMachine):
    """Stateful Hypothesis test for ``sync`` between two storages.

    Rules build ``MemoryStorage`` instances with optional misbehaviour
    (etags that change on every read, 'NULL' or ``None`` etags reported by
    upload/update, read-only mode, actions replaced by ``action_failure``),
    add and remove items directly, and then run ``sync`` between two distinct
    storages, checking invariants on the resulting items and status.
    """

    # Bundles of sync-status dicts and of storages produced by the rules below.
    Status = Bundle('status')
    Storage = Bundle('storage')

    @rule(target=Storage,
          flaky_etags=st.booleans(),
          null_etag_on_upload=st.booleans())
    def newstorage(self, flaky_etags, null_etag_on_upload):
        """Create a ``MemoryStorage``, optionally with patched methods.

        Args:
            flaky_etags: if true, ``get`` assigns and returns a fresh random
                etag on every call.
            null_etag_on_upload: if true, ``upload`` reports the etag 'NULL',
                and ``update`` returns 'NULL' whenever the original update
                returned a truthy value (otherwise that falsy value).

        Returns:
            The new storage (the rule targets the ``Storage`` bundle).
        """
        s = MemoryStorage()
        if flaky_etags:

            def get(href):
                # Keep the stored item but replace its etag, so every read
                # observes a different etag.
                old_etag, item = s.items[href]
                etag = _random_string()
                s.items[href] = etag, item
                return item, etag

            s.get = get

        if null_etag_on_upload:
            # Capture the original bound methods before overriding them.
            _old_upload = s.upload
            _old_update = s.update
            s.upload = lambda item: (_old_upload(item)[0], 'NULL')
            s.update = lambda h, i, e: _old_update(h, i, e) and 'NULL'

        return s

    @rule(s=Storage, read_only=st.booleans())
    def is_read_only(self, s, read_only):
        """Set the storage's read-only flag.

        Examples where the flag would not change are discarded via ``assume``.
        """
        assume(s.read_only != read_only)
        s.read_only = read_only

    @rule(s=Storage)
    def actions_fail(self, s):
        """Replace upload, update and delete with ``action_failure``.

        NOTE(review): presumably ``action_failure`` raises
        ``ActionIntentionallyFailed``, which the ``sync`` rule tolerates.
        """
        s.upload = action_failure
        s.update = action_failure
        s.delete = action_failure

    @rule(s=Storage)
    def none_as_etag(self, s):
        """Wrap upload/update so that both report ``None`` as the new etag."""
        _old_upload = s.upload
        _old_update = s.update

        def upload(item):
            # Keep the href from the original upload, drop its etag.
            return _old_upload(item)[0], None

        def update(href, item, etag):
            # Perform the update but return None instead of the new etag.
            _old_update(href, item, etag)

        s.upload = upload
        s.update = update

    @rule(target=Status)
    def newstatus(self):
        """Start a fresh, empty sync status."""
        return {}

    @rule(storage=Storage, uid=uid_strategy, etag=st.text())
    def upload(self, storage, uid, etag):
        """Store an item directly in ``storage.items``, bypassing ``upload()``."""
        item = Item('UID:{}'.format(uid))
        storage.items[uid] = (etag, item)

    @rule(storage=Storage, href=st.text())
    def delete(self, storage, href):
        """Remove ``href`` from ``storage.items``.

        ``pop`` returns None for a missing href, so ``assume`` discards
        examples that delete nothing (stored values are non-empty tuples).
        """
        assume(storage.items.pop(href, None))

    @rule(status=Status,
          a=Storage,
          b=Storage,
          force_delete=st.booleans(),
          conflict_resolution=st.one_of(
              (st.just('a wins'), st.just('b wins'))),
          with_error_callback=st.booleans(),
          partial_sync=st.one_of(
              (st.just('ignore'), st.just('revert'), st.just('error'))))
    def sync(self, status, a, b, force_delete, conflict_resolution,
             with_error_callback, partial_sync):
        """Sync two distinct storages and check the outcome.

        Expected exceptions are tolerated or checked against the inputs that
        should cause them; on success, both storages must hold the same items
        (unless partial_sync is 'ignore'), read-only storages must be
        unchanged, and the status keys must cover every href.
        """
        assume(a is not b)
        # Snapshot items so read-only storages can be checked afterwards.
        old_items_a = items(a)
        old_items_b = items(b)

        a.instance_name = 'a'
        b.instance_name = 'b'

        errors = []

        # With a callback, errors are collected and re-raised below instead
        # of propagating out of sync() directly.
        if with_error_callback:
            error_callback = errors.append
        else:
            error_callback = None

        try:
            # If one storage is read-only, double-sync because changes don't
            # get reverted immediately.
            for _ in range(2 if a.read_only or b.read_only else 1):
                sync(a,
                     b,
                     status,
                     force_delete=force_delete,
                     conflict_resolution=conflict_resolution,
                     error_callback=error_callback,
                     partial_sync=partial_sync)

            # Re-raise the first collected error (if any) so it reaches the
            # handlers below.
            for e in errors:
                raise e
        except PartialSync:
            assert partial_sync == 'error'
        except ActionIntentionallyFailed:
            pass
        except BothReadOnly:
            assert a.read_only and b.read_only
            # Nothing meaningful to check; discard this example.
            assume(False)
        except StorageEmpty:
            if force_delete:
                raise
            else:
                assert not list(a.list()) or not list(b.list())
        else:
            items_a = items(a)
            items_b = items(b)

            assert items_a == items_b or partial_sync == 'ignore'
            assert items_a == old_items_a or not a.read_only
            assert items_b == old_items_b or not b.read_only

            assert set(a.items) | set(b.items) == set(status) or \
                partial_sync == 'ignore'
Esempio n. 7
0
@composite
def secret(draw):
    """Strategy drawing a secret built by calling ``random_secret()``."""
    secret_strategy = builds(random_secret)
    return draw(secret_strategy)


def event_types_match(events, *expected_types):
    """Return True when the types of ``events`` equal ``expected_types`` as a multiset.

    Order is ignored, but each type must occur exactly as many times as it
    appears in ``expected_types``.
    """
    observed = Counter(map(type, events))
    return observed == Counter(expected_types)


def transferred_amount(state):
    """Return the transferred amount from ``state.balance_proof``, or 0 without one."""
    if state.balance_proof:
        return state.balance_proof.transferred_amount
    return 0


# Shared bundle of ChainStateStateMachine and all mixin classes.
partners = Bundle('partners')


class ChainStateStateMachine(RuleBasedStateMachine):
    """State machine base that tracks per-partner channel bookkeeping."""

    def __init__(self, address=None):
        """Initialise the machine for the node at ``address``.

        Args:
            address: our node address; when omitted (or falsy) a fresh one is
                generated with ``factories.make_address()``.
        """
        # RuleBasedStateMachine.__init__ sets up Hypothesis' internal machine
        # state (bundles etc.); skipping it breaks every rule using a bundle.
        super().__init__()
        self.address = address or factories.make_address()
        self.replay_path = False
        self.address_to_channel = {}
        self.address_to_privkey = {}

        # Previously observed values, keyed per partner, defaulting to 0.
        self.our_previous_deposit = defaultdict(int)
        self.partner_previous_deposit = defaultdict(int)
        self.our_previous_transferred = defaultdict(int)
        self.partner_previous_transferred = defaultdict(int)
        self.our_previous_unclaimed = defaultdict(int)
        # Completes the our_/partner_ pair, matching deposit and transferred.
        self.partner_previous_unclaimed = defaultdict(int)
Esempio n. 8
0
)


class IntAdder(RuleBasedStateMachine):
    """State machine with no inline rules; rules are attached via ``define_rule``."""


# Rule 1: draw an integer ``x`` and return it; the rule targets the 'ints'
# bundle.
IntAdder.define_rule(
    targets=(u'ints',), function=lambda self, x: x, arguments={
        u'x': integers()
    }
)

# Rule 2: draw an integer ``x`` plus a value ``y`` from the 'ints' bundle and
# return ``x`` (``y`` is drawn but unused); also targets 'ints'.
IntAdder.define_rule(
    targets=(u'ints',), function=lambda self, x, y: x, arguments={
        u'x': integers(), u'y': Bundle(u'ints'),
    }
)


class ChoosingMachine(GenericStateMachine):
    """Generic state machine whose every step makes one choice from [1, 2, 3]."""

    def steps(self):
        """Return the strategy for a step: a fresh ``choices()`` chooser."""
        chooser_strategy = choices()
        return chooser_strategy

    def execute_step(self, choices):
        """Use the drawn chooser to pick one of 1, 2 or 3."""
        options = [1, 2, 3]
        choices(options)


# Bind ChoosingMachine's unittest TestCase while a Settings(max_examples=10)
# context is active (old-style Hypothesis settings API); presumably the test
# case picks up those settings — confirm against the Hypothesis version used.
with Settings(max_examples=10):
    TestChoosingMachine = ChoosingMachine.TestCase
Esempio n. 9
0
class AuthStateMachine(RuleBasedStateMachine):
    """
    State machine for auth flows

    How to understand this code:

    This code exercises our social auth APIs, which is basically a graph of nodes and edges that the user traverses.
    You can understand the bundles defined below to be the nodes and the methods of this class to be the edges.

    If you add a new state to the auth flows, create a new bundle to represent that state and define
    methods to define transitions into and (optionally) out of that state.
    """

    # pylint: disable=too-many-instance-attributes

    ConfirmationSentAuthStates = Bundle("confirmation-sent")
    ConfirmationRedeemedAuthStates = Bundle("confirmation-redeemed")
    RegisterExtraDetailsAuthStates = Bundle("register-details-extra")

    LoginPasswordAuthStates = Bundle("login-password")
    LoginPasswordAbandonedAuthStates = Bundle("login-password-abandoned")

    # Patchers are created once on the class. __init__/teardown start and stop
    # the email/courseware ones for every run; the recaptcha one is entered as
    # a context manager only by rules that enable recaptcha.
    recaptcha_patcher = patch(
        "authentication.views.requests.post",
        return_value=MockResponse(content='{"success": true}',
                                  status_code=status.HTTP_200_OK),
    )
    email_send_patcher = patch("mail.verification_api.send_verification_email",
                               autospec=True)
    courseware_api_patcher = patch(
        "authentication.pipeline.user.courseware_api")
    courseware_tasks_patcher = patch(
        "authentication.pipeline.user.courseware_tasks")

    def __init__(self):
        """Setup the machine"""
        super().__init__()
        # wrap the execution in a django transaction, similar to django's TestCase
        self.atomic = transaction.atomic()
        self.atomic.__enter__()

        # wrap the execution in a patch()
        self.mock_email_send = self.email_send_patcher.start()
        self.mock_courseware_api = self.courseware_api_patcher.start()
        self.mock_courseware_tasks = self.courseware_tasks_patcher.start()

        # django test client
        self.client = Client()

        # shared data
        self.email = fake.email()
        self.user = None
        self.password = "******"

        # track whether we've hit an action that starts a flow or not
        self.flow_started = False

    def teardown(self):
        """Cleanup from a run"""
        # clear the mailbox
        del mail.outbox[:]

        # stop the patches
        self.email_send_patcher.stop()
        self.courseware_api_patcher.stop()
        self.courseware_tasks_patcher.stop()

        # end the transaction with a rollback to cleanup any state
        transaction.set_rollback(True)
        self.atomic.__exit__(None, None, None)

    def create_existing_user(self):
        """Create an existing user"""
        self.user = UserFactory.create(email=self.email)
        self.user.set_password(self.password)
        self.user.save()
        UserSocialAuthFactory.create(user=self.user,
                                     provider=EmailAuth.name,
                                     uid=self.user.email)

    def _register_details_payload(self, auth_state):
        """Build the ``psa-register-details`` request body for ``auth_state``.

        Every register-details rule submits the same password, name and legal
        address; only the flow and partial token come from the prior state.

        Args:
            auth_state (dict): a previous API response providing the ``flow``
                and ``partial_token`` keys.

        Returns:
            dict: a newly built request payload (a fresh dict on every call).
        """
        return {
            "flow": auth_state["flow"],
            "partial_token": auth_state["partial_token"],
            "password": self.password,
            "name": "Sally Smith",
            "legal_address": {
                "first_name": "Sally",
                "last_name": "Smith",
                "street_address": ["Main Street"],
                "country": "US",
                "state_or_territory": "US-CO",
                "city": "Boulder",
                "postal_code": "02183",
            },
        }

    @rule(
        target=ConfirmationSentAuthStates,
        recaptcha_enabled=st.sampled_from([True, False]),
    )
    @precondition(lambda self: not self.flow_started)
    def register_email_not_exists(self, recaptcha_enabled):
        """Register email not exists"""
        self.flow_started = True

        with ExitStack() as stack:
            mock_recaptcha_success = None
            if recaptcha_enabled:
                mock_recaptcha_success = stack.enter_context(
                    self.recaptcha_patcher)
                stack.enter_context(
                    override_settings(**{"RECAPTCHA_SITE_KEY": "fake"}))
            result = assert_api_call(
                self.client,
                "psa-register-email",
                {
                    "flow": SocialAuthState.FLOW_REGISTER,
                    "email": self.email,
                    **({
                        "recaptcha": "fake"
                    } if recaptcha_enabled else {}),
                },
                {
                    "flow": SocialAuthState.FLOW_REGISTER,
                    "partial_token": None,
                    "state": SocialAuthState.STATE_REGISTER_CONFIRM_SENT,
                },
            )
            self.mock_email_send.assert_called_once()
            if mock_recaptcha_success:
                mock_recaptcha_success.assert_called_once()
            return result

    @rule(target=LoginPasswordAuthStates,
          recaptcha_enabled=st.sampled_from([True, False]))
    @precondition(lambda self: not self.flow_started)
    def register_email_exists(self, recaptcha_enabled):
        """Register email exists"""
        self.flow_started = True
        self.create_existing_user()

        with ExitStack() as stack:
            mock_recaptcha_success = None
            if recaptcha_enabled:
                mock_recaptcha_success = stack.enter_context(
                    self.recaptcha_patcher)
                stack.enter_context(
                    override_settings(**{"RECAPTCHA_SITE_KEY": "fake"}))

            result = assert_api_call(
                self.client,
                "psa-register-email",
                {
                    "flow": SocialAuthState.FLOW_REGISTER,
                    "email": self.email,
                    "next": NEXT_URL,
                    **({
                        "recaptcha": "fake"
                    } if recaptcha_enabled else {}),
                },
                {
                    "flow": SocialAuthState.FLOW_REGISTER,
                    "state": SocialAuthState.STATE_LOGIN_PASSWORD,
                    "errors": ["Password is required to login"],
                },
            )
            self.mock_email_send.assert_not_called()
            if mock_recaptcha_success:
                mock_recaptcha_success.assert_called_once()
            return result

    @rule()
    @precondition(lambda self: not self.flow_started)
    def register_email_not_exists_with_recaptcha_invalid(self):
        """Yield a function for this step"""
        self.flow_started = True
        with patch(
                "authentication.views.requests.post",
                return_value=MockResponse(
                    content=
                    '{"success": false, "error-codes": ["bad-request"]}',
                    status_code=status.HTTP_200_OK,
                ),
        ) as mock_recaptcha_failure, override_settings(
                **{"RECAPTCHA_SITE_KEY": "fakse"}):
            assert_api_call(
                self.client,
                "psa-register-email",
                {
                    "flow": SocialAuthState.FLOW_REGISTER,
                    "email": NEW_EMAIL,
                    "recaptcha": "fake",
                },
                {
                    "error-codes": ["bad-request"],
                    "success": False
                },
                expect_status=status.HTTP_400_BAD_REQUEST,
                use_defaults=False,
            )
            mock_recaptcha_failure.assert_called_once()
            self.mock_email_send.assert_not_called()

    @rule()
    @precondition(lambda self: not self.flow_started)
    def login_email_not_exists(self):
        """Login for an email that doesn't exist"""
        self.flow_started = True
        assert_api_call(
            self.client,
            "psa-login-email",
            {
                "flow": SocialAuthState.FLOW_LOGIN,
                "email": self.email
            },
            {
                "field_errors": {
                    "email": "Couldn't find your account"
                },
                "flow": SocialAuthState.FLOW_LOGIN,
                "partial_token": None,
                "state": SocialAuthState.STATE_REGISTER_REQUIRED,
            },
        )
        assert User.objects.filter(email=self.email).exists() is False

    @rule(target=LoginPasswordAuthStates)
    @precondition(lambda self: not self.flow_started)
    def login_email_exists(self):
        """Login with a user that exists"""
        self.flow_started = True
        self.create_existing_user()

        return assert_api_call(
            self.client,
            "psa-login-email",
            {
                "flow": SocialAuthState.FLOW_LOGIN,
                "email": self.user.email,
                "next": NEXT_URL,
            },
            {
                "flow": SocialAuthState.FLOW_LOGIN,
                "state": SocialAuthState.STATE_LOGIN_PASSWORD,
                "extra_data": {
                    "name": self.user.name
                },
            },
        )

    @rule(
        target=LoginPasswordAbandonedAuthStates,
        auth_state=consumes(RegisterExtraDetailsAuthStates),
    )
    @precondition(lambda self: self.flow_started)
    def login_email_abandoned(self, auth_state):  # pylint: disable=unused-argument
        """Login with a user that abandoned the register flow"""
        # NOTE: This works by "consuming" an extra details auth state,
        #       but discarding the state and starting a new login.
        #       It then re-targets the new state into the extra details again.
        auth_state = None  # assign None to ensure no accidental usage here

        return assert_api_call(
            self.client,
            "psa-login-email",
            {
                "flow": SocialAuthState.FLOW_LOGIN,
                "email": self.user.email,
                "next": NEXT_URL,
            },
            {
                "flow": SocialAuthState.FLOW_LOGIN,
                "state": SocialAuthState.STATE_LOGIN_PASSWORD,
                "extra_data": {
                    "name": self.user.name
                },
            },
        )

    @rule(
        target=RegisterExtraDetailsAuthStates,
        auth_state=consumes(LoginPasswordAbandonedAuthStates),
    )
    def login_password_abandoned(self, auth_state):
        """Login with an abandoned registration user"""
        return assert_api_call(
            self.client,
            "psa-login-password",
            {
                "flow": auth_state["flow"],
                "partial_token": auth_state["partial_token"],
                "password": self.password,
            },
            {
                "flow": auth_state["flow"],
                "state": SocialAuthState.STATE_REGISTER_EXTRA_DETAILS,
            },
        )

    @rule(auth_state=consumes(LoginPasswordAuthStates))
    def login_password_valid(self, auth_state):
        """Login with a valid password"""
        assert_api_call(
            self.client,
            "psa-login-password",
            {
                "flow": auth_state["flow"],
                "partial_token": auth_state["partial_token"],
                "password": self.password,
            },
            {
                "flow": auth_state["flow"],
                "redirect_url": NEXT_URL,
                "partial_token": None,
                "state": SocialAuthState.STATE_SUCCESS,
            },
            expect_authenticated=True,
        )

    @rule(target=LoginPasswordAuthStates,
          auth_state=consumes(LoginPasswordAuthStates))
    def login_password_invalid(self, auth_state):
        """Login with an invalid password"""
        # NOTE(review): the submitted password literal below is identical to
        # self.password ("******" in both, possibly redacted). If they really
        # are equal this is not an invalid login — confirm the intended value.
        return assert_api_call(
            self.client,
            "psa-login-password",
            {
                "flow": auth_state["flow"],
                "partial_token": auth_state["partial_token"],
                "password": "******",
            },
            {
                "field_errors": {
                    "password":
                    "******"
                },
                "flow": auth_state["flow"],
                "state": SocialAuthState.STATE_ERROR,
            },
        )

    @rule(
        auth_state=consumes(LoginPasswordAuthStates),
        verify_exports=st.sampled_from([True, False]),
    )
    def login_password_user_inactive(self, auth_state, verify_exports):
        """Login for an inactive user"""
        self.user.is_active = False
        self.user.save()

        cm = export_check_response("100_success") if verify_exports else noop()

        with cm:
            assert_api_call(
                self.client,
                "psa-login-password",
                {
                    "flow": auth_state["flow"],
                    "partial_token": auth_state["partial_token"],
                    "password": self.password,
                },
                {
                    "flow": auth_state["flow"],
                    "redirect_url": NEXT_URL,
                    "partial_token": None,
                    "state": SocialAuthState.STATE_SUCCESS,
                },
                expect_authenticated=True,
            )

    @rule(auth_state=consumes(LoginPasswordAuthStates))
    def login_password_exports_temporary_error(self, auth_state):
        """Login for a user who hasn't been OFAC verified yet"""
        with override_settings(**get_cybersource_test_settings()), patch(
                "authentication.pipeline.compliance.api.verify_user_with_exports",
                side_effect=Exception(
                    "register_details_export_temporary_error"),
        ):
            assert_api_call(
                self.client,
                "psa-login-password",
                {
                    "flow": auth_state["flow"],
                    "partial_token": auth_state["partial_token"],
                    "password": self.password,
                },
                {
                    "flow": auth_state["flow"],
                    "partial_token": None,
                    "state": SocialAuthState.STATE_ERROR_TEMPORARY,
                    "errors": [
                        "Unable to register at this time, please try again later"
                    ],
                },
            )

    @rule(
        target=ConfirmationRedeemedAuthStates,
        auth_state=consumes(ConfirmationSentAuthStates),
    )
    def redeem_confirmation_code(self, auth_state):
        """Redeem a registration confirmation code"""
        # The code and partial token come from the last (mocked) email send.
        _, _, code, partial_token = self.mock_email_send.call_args[0]
        return assert_api_call(
            self.client,
            "psa-register-confirm",
            {
                "flow": auth_state["flow"],
                "verification_code": code.code,
                "partial_token": partial_token,
            },
            {
                "flow": auth_state["flow"],
                "state": SocialAuthState.STATE_REGISTER_DETAILS,
            },
        )

    @rule(auth_state=consumes(ConfirmationRedeemedAuthStates))
    def redeem_confirmation_code_twice(self, auth_state):
        """Redeeming a code twice should fail"""
        _, _, code, partial_token = self.mock_email_send.call_args[0]
        assert_api_call(
            self.client,
            "psa-register-confirm",
            {
                "flow": auth_state["flow"],
                "verification_code": code.code,
                "partial_token": partial_token,
            },
            {
                "errors": [],
                "flow": auth_state["flow"],
                "redirect_url": None,
                "partial_token": None,
                "state": SocialAuthState.STATE_INVALID_LINK,
            },
        )

    @rule(auth_state=consumes(ConfirmationRedeemedAuthStates))
    def redeem_confirmation_code_twice_existing_user(self, auth_state):
        """Redeeming a code twice with an existing user should fail with existing account state"""
        _, _, code, partial_token = self.mock_email_send.call_args[0]
        self.create_existing_user()
        assert_api_call(
            self.client,
            "psa-register-confirm",
            {
                "flow": auth_state["flow"],
                "verification_code": code.code,
                "partial_token": partial_token,
            },
            {
                "errors": [],
                "flow": auth_state["flow"],
                "redirect_url": None,
                "partial_token": None,
                "state": SocialAuthState.STATE_EXISTING_ACCOUNT,
            },
        )

    @rule(
        target=RegisterExtraDetailsAuthStates,
        auth_state=consumes(ConfirmationRedeemedAuthStates),
    )
    def register_details(self, auth_state):
        """Complete the register confirmation details page"""
        result = assert_api_call(
            self.client,
            "psa-register-details",
            self._register_details_payload(auth_state),
            {
                "flow": auth_state["flow"],
                "state": SocialAuthState.STATE_REGISTER_EXTRA_DETAILS,
            },
        )
        self.user = User.objects.get(email=self.email)
        return result

    @rule(
        target=RegisterExtraDetailsAuthStates,
        auth_state=consumes(ConfirmationRedeemedAuthStates),
    )
    def register_details_export_success(self, auth_state):
        """Complete the register confirmation details page with exports enabled"""
        with export_check_response("100_success"):
            result = assert_api_call(
                self.client,
                "psa-register-details",
                self._register_details_payload(auth_state),
                {
                    "flow": auth_state["flow"],
                    "state": SocialAuthState.STATE_REGISTER_EXTRA_DETAILS,
                },
            )
            assert ExportsInquiryLog.objects.filter(
                user__email=self.email).exists()
            assert (ExportsInquiryLog.objects.get(
                user__email=self.email).computed_result == RESULT_SUCCESS)
            # A successful export check sends no email.
            assert len(mail.outbox) == 0

            self.user = User.objects.get(email=self.email)
            return result

    @rule(auth_state=consumes(ConfirmationRedeemedAuthStates))
    def register_details_export_reject(self, auth_state):
        """Complete the register confirmation details page with exports enabled"""
        with export_check_response("700_reject"):
            assert_api_call(
                self.client,
                "psa-register-details",
                self._register_details_payload(auth_state),
                {
                    "flow": auth_state["flow"],
                    "partial_token": None,
                    "errors": ["Error code: CS_700"],
                    "state": SocialAuthState.STATE_USER_BLOCKED,
                },
            )
            assert ExportsInquiryLog.objects.filter(
                user__email=self.email).exists()
            assert (ExportsInquiryLog.objects.get(
                user__email=self.email).computed_result == RESULT_DENIED)
            # A rejection is expected to produce exactly one email.
            assert len(mail.outbox) == 1

    @rule(auth_state=consumes(ConfirmationRedeemedAuthStates))
    def register_details_export_temporary_error(self, auth_state):
        """Complete the register confirmation details page with exports raising a temporary error"""
        with override_settings(**get_cybersource_test_settings()), patch(
                "authentication.pipeline.compliance.api.verify_user_with_exports",
                side_effect=Exception(
                    "register_details_export_temporary_error"),
        ):
            assert_api_call(
                self.client,
                "psa-register-details",
                self._register_details_payload(auth_state),
                {
                    "flow": auth_state["flow"],
                    "partial_token": None,
                    "errors": [
                        "Unable to register at this time, please try again later"
                    ],
                    "state": SocialAuthState.STATE_ERROR_TEMPORARY,
                },
            )
            # The export check raised before logging anything or sending mail.
            assert not ExportsInquiryLog.objects.filter(
                user__email=self.email).exists()
            assert len(mail.outbox) == 0

    @rule(auth_state=consumes(RegisterExtraDetailsAuthStates))
    def register_user_extra_details(self, auth_state):
        """Complete the user's extra details"""
        # A fresh Client is used here; the partial token carries the flow.
        assert_api_call(
            Client(),
            "psa-register-extra",
            {
                "flow": auth_state["flow"],
                "partial_token": auth_state["partial_token"],
                "gender": "f",
                "birth_year": "2000",
                "company": "MIT",
                "job_title": "QA Manager",
            },
            {
                "flow": auth_state["flow"],
                "state": SocialAuthState.STATE_SUCCESS,
                "partial_token": None,
            },
            expect_authenticated=True,
        )
Esempio n. 10
0
def test_rule_deprecation_bundle_by_name():
    """Check that ``rule`` rejects a bundle referenced by its string name.

    After the check, a ``TrustchainValidate`` state machine class is defined.
    NOTE(review): this looks like two unrelated examples spliced together --
    the class is defined but never run here, and it relies on ``caplog`` and
    ``coolorg``, which are not defined in this function (they look like pytest
    fixtures of the original test). Confirm and split if appropriate.
    """
    Bundle("k")
    # Passing the bundle's *name* instead of the Bundle object is rejected.
    with pytest.raises(InvalidArgument):
        rule(target="k")

    # Counter shared through ``nonlocal`` by the state machine below to build
    # unique user and device names. Without a binding in this enclosing
    # function the ``nonlocal`` statements are a compile-time SyntaxError.
    name_count = 0

    class TrustchainValidate(RuleBasedStateMachine):
        """Build random users/devices/revocations and check trustchain loading.

        Certificates are signed as they are created and stored alongside
        their plain contents; ``load_trustchain`` then checks that
        ``TrustchainContext.load_user_and_devices`` returns exactly the
        contents that were signed.
        """

        NonRevokedAdminUsers = Bundle("admin users")
        NonRevokedOtherUsers = Bundle("other users")
        RevokedUsers = Bundle("revoked users")

        def next_user_id(self):
            """Return a fresh ``UserID`` built from the shared counter."""
            nonlocal name_count
            name_count += 1
            return UserID(f"user{name_count}")

        def next_device_id(self, user_id=None):
            """Return a fresh device id for ``user_id`` (or for a new user)."""
            nonlocal name_count
            user_id = user_id or self.next_user_id()
            name_count += 1
            return user_id.to_device_id(DeviceName(f"dev{name_count}"))

        def new_user_and_device(self, is_admin, certifier_id, certifier_key):
            """Create a new user with one device and sign both certificates.

            The user certificate and device certificate are authored by
            ``certifier_id`` and signed with ``certifier_key``. Contents and
            signed certificates are recorded in the state dictionaries.
            Returns the new device id.
            """
            device_id = self.next_device_id()

            local_device = local_device_factory(device_id, org=coolorg)
            # Only the first device of each user is tracked here; get_device
            # and revoke_user pick authors from this mapping.
            self.local_devices[device_id] = local_device

            user = UserCertificateContent(
                author=certifier_id,
                timestamp=pendulum_now(),
                user_id=local_device.user_id,
                human_handle=local_device.human_handle,
                public_key=local_device.public_key,
                profile=UserProfile.ADMIN
                if is_admin else UserProfile.STANDARD,
            )
            self.users_content[device_id.user_id] = user
            self.users_certifs[device_id.user_id] = user.dump_and_sign(
                certifier_key)

            device = DeviceCertificateContent(
                author=certifier_id,
                timestamp=pendulum_now(),
                device_id=local_device.device_id,
                device_label=local_device.device_label,
                verify_key=local_device.verify_key,
            )
            self.devices_content[local_device.device_id] = device
            self.devices_certifs[
                local_device.device_id] = device.dump_and_sign(certifier_key)

            return device_id

        @initialize(target=NonRevokedAdminUsers)
        def init(self):
            """Reset state and create the first admin, signed by the org root key."""
            caplog.clear()
            self.users_certifs = {}
            self.users_content = {}
            self.revoked_users_certifs = {}
            self.revoked_users_content = {}
            self.devices_certifs = {}
            self.devices_content = {}
            self.local_devices = {}
            device_id = self.new_user_and_device(
                is_admin=True,
                certifier_id=None,
                certifier_key=coolorg.root_signing_key)
            note(f"new device: {device_id}")
            return device_id.user_id

        def get_device(self, user_id, device_rand):
            """Pick one of ``user_id``'s tracked local devices using ``device_rand``."""
            user_devices = [
                device for device_id, device in self.local_devices.items()
                if device_id.user_id == user_id
            ]
            return user_devices[device_rand % len(user_devices)]

        @rule(
            target=NonRevokedAdminUsers,
            author_user=NonRevokedAdminUsers,
            author_device_rand=st.integers(min_value=0),
        )
        def new_admin_user(self, author_user, author_device_rand):
            """Create a new admin user certified by an existing admin's device."""
            author = self.get_device(author_user, author_device_rand)
            device_id = self.new_user_and_device(
                is_admin=True,
                certifier_id=author.device_id,
                certifier_key=author.signing_key)
            note(f"new device: {device_id} (author: {author.device_id})")
            return device_id.user_id

        @rule(
            target=NonRevokedOtherUsers,
            author_user=NonRevokedAdminUsers,
            author_device_rand=st.integers(min_value=0),
        )
        def new_non_admin_user(self, author_user, author_device_rand):
            """Create a new standard user certified by an existing admin's device."""
            author = self.get_device(author_user, author_device_rand)
            device_id = self.new_user_and_device(
                is_admin=False,
                certifier_id=author.device_id,
                certifier_key=author.signing_key)
            note(f"new device: {device_id} (author: {author.device_id})")
            return device_id.user_id

        # NOTE(review): revoked users are never removed from local_devices,
        # so this count (and possible_authors below) can include admins that
        # were already revoked. The precondition also uses ``d.is_admin``
        # while possible_authors uses ``profile == UserProfile.ADMIN``;
        # confirm both attributes agree on LocalDevice.
        @precondition(lambda self: len(
            [d for d in self.local_devices.values() if d.is_admin]) > 1)
        @rule(
            target=RevokedUsers,
            user=st.one_of(consumes(NonRevokedAdminUsers),
                           consumes(NonRevokedOtherUsers)),
            author_rand=st.integers(min_value=0),
        )
        def revoke_user(self, user, author_rand):
            """Revoke ``user``, authored by an admin device of another user."""
            possible_authors = [
                device for device_id, device in self.local_devices.items()
                if device_id.user_id != user
                and device.profile == UserProfile.ADMIN
            ]
            author = possible_authors[author_rand % len(possible_authors)]
            note(f"revoke user: {user} (author: {author.device_id})")
            revoked_user = RevokedUserCertificateContent(
                author=author.device_id,
                timestamp=pendulum_now(),
                user_id=user)
            self.revoked_users_content[user] = revoked_user
            self.revoked_users_certifs[user] = revoked_user.dump_and_sign(
                author.signing_key)
            return user

        @rule(
            user=st.one_of(NonRevokedAdminUsers, NonRevokedOtherUsers),
            author_user=NonRevokedAdminUsers,
            author_device_rand=st.integers(min_value=0),
        )
        def new_device(self, user, author_user, author_device_rand):
            """Add and certify an extra device for a non-revoked user.

            Only the certificate/content dictionaries are updated; the new
            device is not added to ``local_devices``.
            """
            author = self.get_device(author_user, author_device_rand)
            device_id = self.next_device_id(user)
            note(f"new device: {device_id} (author: {author.device_id})")
            local_device = local_device_factory(device_id, org=coolorg)
            device = DeviceCertificateContent(
                author=author.device_id,
                timestamp=pendulum_now(),
                device_id=local_device.device_id,
                device_label=local_device.device_label,
                verify_key=local_device.verify_key,
            )
            self.devices_content[local_device.device_id] = device
            self.devices_certifs[
                local_device.device_id] = device.dump_and_sign(
                    author.signing_key)

        @rule(user=st.one_of(NonRevokedAdminUsers, NonRevokedOtherUsers))
        def load_trustchain(self, user):
            """Load ``user`` through the full trustchain and compare with the model."""
            ctx = TrustchainContext(coolorg.root_verify_key, 1)

            user_certif = next(
                certif for user_id, certif in self.users_certifs.items()
                if user_id == user)
            revoked_user_certif = next(
                (certif
                 for user_id, certif in self.revoked_users_certifs.items()
                 if user_id == user),
                None,
            )
            devices_certifs = [
                certif for device_id, certif in self.devices_certifs.items()
                if device_id.user_id == user
            ]
            user_content, revoked_user_content, devices_contents = ctx.load_user_and_devices(
                trustchain={
                    "users":
                    [certif for certif in self.users_certifs.values()],
                    "revoked_users":
                    [certif for certif in self.revoked_users_certifs.values()],
                    "devices":
                    [certif for certif in self.devices_certifs.values()],
                },
                user_certif=user_certif,
                revoked_user_certif=revoked_user_certif,
                devices_certifs=devices_certifs,
                expected_user_id=user,
            )

            expected_user_content = next(
                content for user_id, content in self.users_content.items()
                if user_id == user)
            expected_revoked_user_content = next(
                (content
                 for user_id, content in self.revoked_users_content.items()
                 if user_id == user),
                None,
            )
            expected_devices_contents = [
                content for device_id, content in self.devices_content.items()
                if device_id.user_id == user
            ]
            assert user_content == expected_user_content
            assert revoked_user_content == expected_revoked_user_content
            # Device order is not significant; compare sorted by device id.
            assert sorted(devices_contents,
                          key=lambda device: device.device_id) == sorted(
                              expected_devices_contents,
                              key=lambda device: device.device_id)
Esempio n. 12
0
def test_rule_deprecation_targets_and_target():
    """``rule`` must refuse to receive both ``targets`` and ``target`` at once."""
    keys_bundle = Bundle("k")
    values_bundle = Bundle("v")
    with pytest.raises(InvalidArgument):
        rule(targets=(keys_bundle,), target=values_bundle)
Esempio n. 13
0
def test_deprecated_target_consumes_bundle():
    """Using ``consumes(...)`` as a rule target is reported as deprecated."""
    # Reporting this at runtime would be preferable, but the internals make
    # it impractical. Most InvalidDefinition errors are already raised at
    # definition time, so this is no worse than the status quo.
    with validate_deprecation():
        consumed_bundle = consumes(Bundle("b"))
        rule(target=consumed_bundle)
Esempio n. 14
0
class StatefulTestFileGenerator(RuleBasedStateMachine):
    """Stateful test for ``FileGenerator``.

    Random model objects are built from generated names, constants, fields
    and meta options. Pairs of them are handed to ``FileGenerator``, and the
    resulting ``FILE`` is checked:
    - earlier content is preserved;
    - it starts with ``FILE_HEADER``;
    - every generated assert compares the same attribute on both models.
    """

    model_object = Bundle("model_object")

    name = Bundle("name")
    constants = Bundle("constants")
    fields = Bundle("fields")
    meta = Bundle("meta")

    @initialize()
    def remove_generated_file(self):
        """Delete any previously generated file so each run starts clean."""
        if os.path.isfile(FILE):
            os.remove(FILE)

    @rule(target=name, name=fake_class_name())
    def add_name(self, name):
        """Provide a class name that does not match an existing model."""
        assume(not model_exists(name))
        return name

    @rule(target=constants, constants=fake_constants())
    def add_constants(self, constants):
        """Provide a generated set of class constants."""
        return constants

    @rule(target=fields, fields=fake_fields_data())
    def add_fields(self, fields):
        """Provide generated model field data."""
        return fields

    @rule(target=meta, meta=default_meta())
    def add_meta(self, meta):
        """Provide generated Meta options."""
        return meta

    @rule(
        target=model_object,
        name=consumes(name),
        constants=constants,
        fields=fields,
        meta=meta,
    )
    def add_model_object(self, name, constants, fields, meta):
        """Build a Django model from the parts and wrap it as a model object.

        Constants whose names collide with a field are dropped. The bundle
        keeps a reference to the drawn ``constants`` and may hand it to later
        rule calls, so a copy is filtered instead of mutating it in place.
        Fails the test if the model cannot be built.
        """
        # Assumes ``constants`` is a dict (it is used with ``pop(key, None)``).
        constants = constants.copy()
        for field in fields:
            constants.pop(field, None)

        try:
            django_model = get_django_model(name=name,
                                            constants=constants,
                                            fields=fields,
                                            meta=meta)
            model_object = get_model_object(django_model)
        except Exception as e:
            # pytest.fail() requires a str reason; passing the exception
            # object itself makes pytest raise TypeError instead.
            pytest.fail(f"{type(e).__name__}: {e}")
        return model_object

    @rule(original=consumes(model_object), tester=consumes(model_object))
    def assert_file_generator(self, original, tester):
        """Run FileGenerator on two model objects and validate the output file."""
        initial_file = []
        if os.path.isfile(FILE):
            event("assert_file_generator: File already exists.")
            with open(FILE, "r") as f:
                initial_file = f.read().splitlines()
        else:
            event("assert_file_generator: File doesn't exists.")

        file_generator_instance = FileGenerator(original, tester)

        # Test File exists.
        assert os.path.isfile(FILE)
        with open(FILE, "r") as f:
            modified_file = f.read().splitlines()

        # Test File has Header. Slicing turns a too-short file into an
        # assertion failure rather than an IndexError.
        header_lines = FILE_HEADER.splitlines()
        assert modified_file[:len(header_lines)] == header_lines

        # Test initial data isn't modified.
        assert modified_file[:len(initial_file)] == initial_file

        # Test names are correct and attributes exist.
        appended_data = modified_file[len(initial_file):]
        pattern = re.compile(
            r"assert (?P<original>\S+) == (?P<tester>\S+), assert_msg(.+)")
        for line in appended_data:
            match = pattern.search(line)
            if not match:
                continue
            # Split each "Name.attr.path" breadcrumb into its model name and
            # the (possibly dotted, possibly empty) attribute path.
            assert_line = {}
            for model, breadcrumb in match.groupdict().items():
                model_name, _, attr = breadcrumb.partition(".")
                assert_line[model] = {"name": model_name, "attr": attr}

            # Test names are correct.
            assert assert_line["original"]["name"] == original._meta.name
            assert assert_line["tester"]["name"] == tester._meta.name

            # Test attributes compared are the same.
            assert assert_line["original"]["attr"] == assert_line[
                "tester"]["attr"]

            # Test attributes exist.
            assert hasattrs(original, assert_line["original"]["attr"])
            assert hasattrs(tester, assert_line["tester"]["attr"])

        # Try retrieve generated functions.
        try:
            generated_functions = file_generator_instance.get_functions()
        except Exception as e:
            pytest.fail(f"{type(e).__name__}: {e}")

        # Test Module Import. Only the sys.modules lookup is expected to
        # raise KeyError, so keep it alone in the try.
        try:
            module = sys.modules[MODULE]
        except KeyError as e:
            pytest.fail(f"{type(e).__name__}: {e}")

        # Test Module has the Generated Class.
        assert hasattr(module, tester._meta.name)

        # Test Generated Class has the generated functions.
        generated_class = getattr(module, tester._meta.name)
        for generated_function in generated_functions.keys():
            assert hasattr(generated_class, generated_function)
Esempio n. 15
0
class HypothesisSpec(RuleBasedStateMachine):
    """State machine that composes Hypothesis strategies and exercises them.

    Rules build ever more complex strategies (maps, filters, unions, tuples,
    lists, flatmaps) and run small ``@given`` tests against them, with or
    without an example database, checking that nothing crashes unexpectedly.
    """

    def __init__(self):
        super(HypothesisSpec, self).__init__()
        self.database = None

    strategies = Bundle(u'strategy')
    strategy_tuples = Bundle(u'tuples')
    objects = Bundle(u'objects')
    # Not referenced by any rule in this class.
    basic_data = Bundle(u'basic')
    varied_floats = Bundle(u'varied_floats')

    def teardown(self):
        """Close any open example database at the end of a run."""
        self.clear_database()

    @rule()
    def clear_database(self):
        """Close and forget the current example database, if any."""
        if self.database is not None:
            self.database.close()
            self.database = None

    @rule()
    def set_database(self):
        """Replace the current database with a fresh ``ExampleDatabase``."""
        self.teardown()
        self.database = ExampleDatabase()

    @rule(strat=strategies, r=integers(), max_shrinks=integers(0, 100))
    def find_constant_failure(self, strat, r, max_shrinks):
        """Run an always-failing test over ``strat``; the failure is expected."""
        with settings(
                verbosity=Verbosity.quiet,
                max_examples=1,
                min_satisfying_examples=0,
                database=self.database,
                max_shrinks=max_shrinks,
        ):

            @given(strat)
            @seed(r)
            def test(x):
                assert False

            try:
                test()
            except (AssertionError, FailedHealthCheck):
                pass

    @rule(strat=strategies,
          r=integers(),
          p=floats(0, 1),
          max_examples=integers(1, 10),
          max_shrinks=integers(1, 100))
    def find_weird_failure(self, strat, r, max_examples, p, max_shrinks):
        """Run a test that fails on a hash-determined subset of examples.

        An example passes when a Random seeded from the MD5 of its repr
        yields a value ``<= p``; failures are expected and swallowed.
        """
        with settings(
                verbosity=Verbosity.quiet,
                max_examples=max_examples,
                min_satisfying_examples=0,
                database=self.database,
                max_shrinks=max_shrinks,
        ):

            @given(strat)
            @seed(r)
            def test(x):
                assert Random(hashlib.md5(
                    repr(x).encode(u'utf-8')).digest()).random() <= p

            try:
                test()
            except (AssertionError, FailedHealthCheck):
                pass

    @rule(target=strategies,
          spec=sampled_from((
              integers(),
              booleans(),
              floats(),
              complex_numbers(),
              fractions(),
              decimals(),
              text(),
              binary(),
              none(),
              tuples(),
          )))
    def strategy(self, spec):
        """Seed the strategies bundle with one of the basic strategies."""
        return spec

    @rule(target=strategies, values=lists(integers() | text(), min_size=1))
    def sampled_from_strategy(self, values):
        """Add a ``sampled_from`` strategy over a non-empty list of values."""
        return sampled_from(values)

    @rule(target=strategies, spec=strategy_tuples)
    def strategy_for_tupes(self, spec):
        """Turn a tuple of strategies into a ``tuples(...)`` strategy."""
        return tuples(*spec)

    @rule(target=strategies,
          source=strategies,
          level=integers(1, 10),
          mixer=text())
    def filtered_strategy(self, source, level, mixer):
        """Filter ``source`` with a deterministic, hash-based predicate.

        A value is kept unless ``randint(0, level)`` on a Random seeded from
        the MD5 of ``mixer + repr(value)`` returns 0.
        """
        def is_good(x):
            return bool(
                Random(
                    hashlib.md5(
                        (mixer + repr(x)).encode(u'utf-8')).digest()).randint(
                            0, level))

        return source.filter(is_good)

    @rule(target=strategies, elements=strategies)
    def list_strategy(self, elements):
        """Wrap a strategy in ``lists(...)``."""
        return lists(elements)

    @rule(target=strategies, left=strategies, right=strategies)
    def or_strategy(self, left, right):
        """Combine two strategies with ``|``."""
        return left | right

    @rule(target=varied_floats, source=floats())
    def float(self, source):
        """Add a raw float to the varied_floats bundle."""
        return source

    @rule(target=varied_floats,
          source=varied_floats,
          offset=integers(-100, 100))
    def adjust_float(self, source, offset):
        """Nudge a float by ``offset`` steps in its integer representation.

        Presumably float_to_int/int_to_float are 64-bit bit-casts; the result
        is clamped to [0, 2**64 - 1] before converting back.
        """
        return int_to_float(clamp(0, float_to_int(source) + offset, 2**64 - 1))

    @rule(target=strategies, left=varied_floats, right=varied_floats)
    def float_range(self, left, right):
        """Build ``floats(lo, hi)`` from two finite floats."""
        for f in (math.isnan, math.isinf):
            for x in (left, right):
                assume(not f(x))
        left, right = sorted((left, right))
        assert left <= right
        return floats(left, right)

    @rule(target=strategies,
          source=strategies,
          result1=strategies,
          result2=strategies,
          mixer=text(),
          p=floats(0, 1))
    def flatmapped_strategy(self, source, result1, result2, mixer, p):
        """Flatmap ``source`` onto one of two strategies via a hashed coin flip."""
        assume(result1 is not result2)

        def do_map(value):
            rep = repr(value)
            random = Random(
                hashlib.md5((mixer + rep).encode(u'utf-8')).digest())
            if random.random() <= p:
                return result1
            else:
                return result2

        return source.flatmap(do_map)

    @rule(target=strategies, value=objects)
    def just_strategy(self, value):
        """Wrap a previously drawn example in ``just(...)``."""
        return just(value)

    @rule(target=strategy_tuples, source=strategies)
    def single_tuple(self, source):
        """Start a tuple of strategies with a single element."""
        return (source, )

    @rule(target=strategy_tuples, left=strategy_tuples, right=strategy_tuples)
    def cat_tuples(self, left, right):
        """Concatenate two tuples of strategies."""
        return left + right

    @rule(target=objects, strat=strategies, data=data())
    def get_example(self, strat, data):
        """Draw an example from ``strat`` and store it in the objects bundle.

        The drawn value must be returned; otherwise the bundle only ever
        holds None and ``just_strategy`` degenerates to ``just(None)``.
        """
        return data.draw(strat)

    @rule(target=strategies, left=integers(), right=integers())
    def integer_range(self, left, right):
        """Build ``integers(lo, hi)`` from two integers in either order."""
        left, right = sorted((left, right))
        return integers(left, right)

    @rule(strat=strategies)
    def repr_is_good(self, strat):
        """Strategy reprs must not leak default object reprs (`` at 0x``)."""
        assert u' at 0x' not in repr(strat)
Esempio n. 16
0
class GraphCompare(RuleBasedStateMachine):
    """Differential test of mygrad's ``Tensor`` graph against a naive ``Node`` graph.

    Every rule applies the same operation to a naive ``Node`` and to the
    corresponding ``Tensor``. The invariant checks that data, gradients and
    graph bookkeeping stay in agreement.
    """

    nodes = Bundle("nodes")

    def __init__(self):
        super().__init__()
        # Paired (Node, Tensor) entries in creation order (from `create_node`
        # or `fuse_nodes`). `Node` is the naive reference implementation that
        # `Tensor` is checked against.
        self.node_list = []  # type: List[Tuple[Node, Tensor]]
        self.str_to_tensor_op = {"add": add, "multiply": multiply}
        self.str_to_node_op = {"add": _add, "multiply": _multiply}
        # Set once a backward pass has (correctly) raised; disables the
        # invariant from then on.
        self.raised = False

    @rule(target=nodes, value=st.floats(-10, 10), constant=st.booleans())
    def create_node(self, value, constant):
        """Create a matching leaf Node/Tensor pair holding ``value``."""
        pair = (Node(value, constant=constant), Tensor(value, constant=constant))
        self.node_list.append(pair)
        return pair

    @rule(
        target=nodes,
        a=nodes,
        b=nodes,
        op=st.sampled_from(["add", "multiply"]),
        constant=st.booleans(),
    )
    def fuse_nodes(self, a, b, op, constant):
        """
        Combine two existing pairs with addition or multiplication, producing
        a new Node/Tensor pair."""
        (node_a, tensor_a), (node_b, tensor_b) = a, b
        node_fn = self.str_to_node_op[op]
        tensor_fn = self.str_to_tensor_op[op]
        fused = (
            node_fn(node_a, node_b, constant=constant),
            tensor_fn(tensor_a, tensor_b, constant=constant),
        )
        self.node_list.append(fused)
        return fused

    @rule(items=nodes, clear_graph=st.booleans())
    def null_gradients(self, items, clear_graph):
        """
        Call `null_gradients` on both graphs, optionally with `clear_graph=True`.
        """
        node, tensor = items
        node.null_gradients(clear_graph=clear_graph)
        tensor.null_gradients(clear_graph=clear_graph)

    @rule(items=nodes)
    def clear_graph(self, items):
        """
        Call `clear_graph` on both graphs.
        """
        node, tensor = items
        node.clear_graph()
        tensor.clear_graph()

    @rule(items=nodes, grad=st.floats(-10, 10))
    def backprop(self, items, grad):
        """
        Back-propagate ``grad`` from a randomly chosen pair in both graphs.

        If the naive ``Node`` refuses (raises), the ``Tensor`` is required to
        raise as well, and the invariant is switched off afterwards.
        """
        node, tensor = items
        node.null_gradients(clear_graph=False)
        tensor.null_gradients(clear_graph=False)
        try:
            node.backward(grad, terminal_node=True)
        except Exception:
            with raises(Exception):
                tensor.backward(grad)
            self.raised = True
        else:
            tensor.backward(grad)

    @precondition(lambda self: not self.raised)
    @invariant()
    def all_agree(self):
        """
        Every Node/Tensor pair must agree on op presence, data and gradient,
        and the Tensor must hold no pending accumulation ops.
        """
        for idx, (node, tensor) in enumerate(self.node_list):
            label = _node_ID_str(idx)
            assert bool(node._ops) is bool(tensor._ops), label
            assert_equal(node.data, tensor.data, err_msg=label)
            both_have_grad = node.grad is not None and tensor.grad is not None
            if both_have_grad:
                assert_allclose(
                    actual=tensor.grad,
                    desired=node.grad,
                    atol=1e-5,
                    rtol=1e-5,
                    err_msg=label,
                )
            else:
                assert node.grad is tensor.grad, label
            assert not tensor._accum_ops, label
Esempio n. 17
0
class SetModel(RuleBasedStateMachine):
    """Model-based test of ``IntSet`` against a sorted Python list.

    Each value in the ``intsets`` bundle is a pair ``(intset, model)``, where
    ``model`` is the sorted list of integers the IntSet is expected to contain.
    """

    intsets = Bundle('IntSets')
    values = Bundle('values')

    @rule(target=values, i=integers_in_range)
    def int_value(self, i):
        """Add an arbitrary in-range integer to the values bundle."""
        return i

    @rule(target=values, i=integers_in_range, imp=intsets)
    def endpoint_value(self, i, imp):
        """Add the largest element of an IntSet (or ``i`` if it is empty)."""
        if len(imp[0]) > 0:
            return imp[0][-1]
        else:
            return i

    @rule(target=values, i=integers_in_range, imp=intsets)
    def startpoint_value(self, i, imp):
        """Add the smallest element of an IntSet (or ``i`` if it is empty)."""
        if len(imp[0]) > 0:
            return imp[0][0]
        else:
            return i

    @rule(target=intsets, bounds=short_intervals)
    def build_interval(self, bounds):
        """Build the half-open interval ``[start, end)``."""
        return (IntSet.interval(*bounds), list(range(*bounds)))

    @rule(target=intsets, v=values)
    def single_value(self, v):
        """Build a singleton set."""
        return (IntSet.single(v), [v])

    @rule(target=intsets, v=values)
    def adjacent_values(self, v):
        """Build ``{v, v+1}`` as an interval."""
        assume(v + 1 <= 2 ** 64)
        return (IntSet.interval(v, v + 2), [v, v + 1])

    @rule(target=intsets, v=values)
    def three_adjacent_values(self, v):
        """Build ``{v, v+1, v+2}`` as an interval."""
        assume(v + 2 <= 2 ** 64)
        return (IntSet.interval(v, v + 3), [v, v + 1, v + 2])

    @rule(target=intsets, v=values)
    def three_adjacent_values_with_hole(self, v):
        """Build ``{v, v+2}`` as a union of two singletons."""
        assume(v + 2 <= 2 ** 64)
        return (IntSet.single(v) | IntSet.single(v + 2), [v, v + 2])

    @rule(target=intsets, x=intsets, y=intsets)
    def union(self, x, y):
        return (x[0] | y[0], sorted(set(x[1] + y[1])))

    @rule(target=intsets, x=intsets, y=intsets)
    def intersect(self, x, y):
        return (x[0] & y[0], sorted(set(x[1]) & set(y[1])))

    @rule(target=intsets, x=intsets, y=intsets)
    def subtract(self, x, y):
        return (x[0] - y[0], sorted(set(x[1]) - set(y[1])))

    @rule(target=intsets, x=intsets, i=values)
    def insert(self, x, i):
        return (x[0].insert(i), sorted(set(x[1] + [i])))

    @rule(target=intsets, x=intsets, i=values)
    def discard(self, x, i):
        return (x[0].discard(i), sorted(set(x[1]) - {i}))

    @rule(target=intsets, source=intsets, bounds=intervals)
    def restrict(self, source, bounds):
        """Keep only the elements in ``[bounds[0], bounds[1])``."""
        return (
            source[0].restrict(*bounds), [x for x in source[1]
                                          if bounds[0] <= x < bounds[1]])

    @rule(target=intsets, x=intsets)
    def peel_left(self, x):
        """Drop the smallest element.

        The restriction starts one past the minimum; starting at the minimum
        itself would keep every element and make this rule a no-op.
        """
        if len(x[0]) == 0:
            return x
        return self.restrict(x, (x[0][0] + 1, x[0][-1] + 1))

    @rule(target=intsets, x=intsets)
    def peel_right(self, x):
        """Drop the largest element."""
        if len(x[0]) == 0:
            return x
        return self.restrict(x, (x[0][0], x[0][-1]))

    @rule(x=intsets, y=intsets)
    def validate_order(self, x, y):
        assert (x[0] <= y[0]) == (x[1] <= y[1])

    @rule(x=intsets, y=intsets)
    def validate_equality(self, x, y):
        assert (x[0] == y[0]) == (x[1] == y[1])

    @rule(source=intsets)
    def validate(self, source):
        """The IntSet must match its model in iteration, length and indexing."""
        assert list(source[0]) == source[1]
        assert len(source[0]) == len(source[1])
        for i in range(-len(source[0]), len(source[0])):
            assert source[0][i] == source[1][i]
        if len(source[0]) > 0:
            for v in source[1]:
                assert source[0][0] <= v <= source[0][-1]
Esempio n. 18
0
class QCircuitMachine(RuleBasedStateMachine):
    """Build a Hypothesis rule based state machine for constructing, transpiling
    and simulating a series of random QuantumCircuits.

    Build circuits with up to QISKIT_RANDOM_QUBITS qubits, apply a random
    selection of gates from qiskit.extensions.standard with randomly selected
    qargs, cargs, and parameters. At random intervals, transpile the circuit for
    a random backend with a random optimization level and simulate both the
    initial and the transpiled circuits to verify that their counts are the
    same.

    """

    qubits = Bundle('qubits')
    clbits = Bundle('clbits')

    backend = Aer.get_backend('qasm_simulator')
    # Circuits use at most half of the simulator's qubits.
    max_qubits = backend.configuration().n_qubits // 2

    def __init__(self):
        super().__init__()
        self.qc = QuantumCircuit()

    @precondition(lambda self: len(self.qc.qubits) < self.max_qubits)
    @rule(target=qubits,
          n=st.integers(min_value=1, max_value=max_qubits))
    def add_qreg(self, n):
        """Adds a new variable sized qreg to the circuit, up to max_qubits."""
        n = min(n, self.max_qubits - len(self.qc.qubits))
        qreg = QuantumRegister(n)
        self.qc.add_register(qreg)
        return multiple(*list(qreg))

    @rule(target=clbits,
          n=st.integers(1, 5))
    def add_creg(self, n):
        """Add a new variable sized creg to the circuit."""
        creg = ClassicalRegister(n)
        self.qc.add_register(creg)
        return multiple(*list(creg))

    # Gates of various shapes

    @rule(gate=st.sampled_from(oneQ_gates),
          qarg=qubits)
    def add_1q_gate(self, gate, qarg):
        """Append a random 1q gate on a random qubit."""
        self.qc.append(gate(), [qarg], [])

    @rule(gate=st.sampled_from(twoQ_gates),
          qargs=st.lists(qubits, max_size=2, min_size=2, unique=True))
    def add_2q_gate(self, gate, qargs):
        """Append a random 2q gate across two random qubits."""
        self.qc.append(gate(), qargs)

    @rule(gate=st.sampled_from(threeQ_gates),
          qargs=st.lists(qubits, max_size=3, min_size=3, unique=True))
    def add_3q_gate(self, gate, qargs):
        """Append a random 3q gate across three random qubits."""
        self.qc.append(gate(), qargs)

    @rule(gate=st.sampled_from(oneQ_oneP_gates),
          qarg=qubits,
          param=st.floats(allow_nan=False, allow_infinity=False))
    def add_1q1p_gate(self, gate, qarg, param):
        """Append a random 1q gate with 1 random float parameter."""
        self.qc.append(gate(param), [qarg])

    @rule(gate=st.sampled_from(oneQ_twoP_gates),
          qarg=qubits,
          params=st.lists(
              st.floats(allow_nan=False, allow_infinity=False),
              min_size=2, max_size=2))
    def add_1q2p_gate(self, gate, qarg, params):
        """Append a random 1q gate with 2 random float parameters."""
        self.qc.append(gate(*params), [qarg])

    @rule(gate=st.sampled_from(oneQ_threeP_gates),
          qarg=qubits,
          params=st.lists(
              st.floats(allow_nan=False, allow_infinity=False),
              min_size=3, max_size=3))
    def add_1q3p_gate(self, gate, qarg, params):
        """Append a random 1q gate with 3 random float parameters."""
        self.qc.append(gate(*params), [qarg])

    @rule(gate=st.sampled_from(twoQ_oneP_gates),
          qargs=st.lists(qubits, max_size=2, min_size=2, unique=True),
          param=st.floats(allow_nan=False, allow_infinity=False))
    def add_2q1p_gate(self, gate, qargs, param):
        """Append a random 2q gate with 1 random float parameter."""
        self.qc.append(gate(param), qargs)

    @rule(gate=st.sampled_from(twoQ_threeP_gates),
          qargs=st.lists(qubits, max_size=2, min_size=2, unique=True),
          params=st.lists(
              st.floats(allow_nan=False, allow_infinity=False),
              min_size=3, max_size=3))
    def add_2q3p_gate(self, gate, qargs, params):
        """Append a random 2q gate with 3 random float parameters."""
        self.qc.append(gate(*params), qargs)

    @rule(gate=st.sampled_from(oneQ_oneC_gates),
          qarg=qubits,
          carg=clbits)
    def add_1q1c_gate(self, gate, qarg, carg):
        """Append a random 1q, 1c gate."""
        self.qc.append(gate(), [qarg], [carg])

    @rule(gate=st.sampled_from(variadic_gates),
          qargs=st.lists(qubits, min_size=1, unique=True))
    def add_variQ_gate(self, gate, qargs):
        """Append a gate with a variable number of qargs."""
        self.qc.append(gate(len(qargs)), qargs)

    @precondition(lambda self: len(self.qc.data) > 0)
    @rule(carg=clbits,
          data=st.data())
    def add_c_if_last_gate(self, carg, data):
        """Modify the last gate to be conditional on a classical register."""
        creg = carg.register
        val = data.draw(st.integers(min_value=0, max_value=2**len(creg)-1))

        last_gate = self.qc.data[-1]

        # Work around for https://github.com/Qiskit/qiskit-terra/issues/2567
        assume(not isinstance(last_gate[0], Measure) or creg != last_gate[2][0].register)

        last_gate[0].c_if(creg, val)

    # Properties to check

    @invariant()
    def qasm(self):
        """After each circuit operation, it should be possible to build QASM."""
        self.qc.qasm()

    def _simulate_counts(self, circuit, shots):
        """Run ``circuit`` on the reference simulator and return its counts."""
        return execute(circuit, backend=self.backend,
                       shots=shots).result().get_counts()

    @precondition(lambda self: any(isinstance(d[0], Measure) for d in self.qc.data))
    @rule(
        backend=st.one_of(
            st.none(),
            st.sampled_from(mock_backends)),
        opt_level=st.one_of(
            st.none(),
            st.integers(min_value=0, max_value=3)))
    def equivalent_transpile(self, backend, opt_level):
        """Simulate, transpile and simulate the present circuit. Verify that the
        counts are not significantly different before and after transpilation.

        Counts may differ per outcome by at most 5% of the shot count.
        """

        assume(backend is None or backend.configuration().n_qubits >= len(self.qc.qubits))

        shots = 4096

        aer_counts = self._simulate_counts(self.qc, shots)

        try:
            xpiled_qc = transpile(self.qc, backend=backend, optimization_level=opt_level)
        except Exception as e:
            failed_qasm = 'Exception caught during transpilation of circuit: \n{}'.format(
                self.qc.qasm())
            raise RuntimeError(failed_qasm) from e

        xpiled_aer_counts = self._simulate_counts(xpiled_qc, shots)

        count_differences = dicts_almost_equal(aer_counts, xpiled_aer_counts, 0.05 * shots)

        assert count_differences == '', 'Counts not equivalent: {}\nFailing QASM: \n{}'.format(
            count_differences, self.qc.qasm())
Esempio n. 19
0
class HypothesisSpec(RuleBasedStateMachine):
    """Randomly compose Hypothesis strategies and check they stay well-behaved.

    Rules build strategies (collected in the ``strategies`` bundle) from
    primitives and combinators (lists, ``|``, tuples, filter, flatmap, just,
    numeric ranges). Other rules draw examples from them and check that their
    ``repr`` does not leak object addresses.
    """

    def __init__(self):
        super().__init__()
        # Optional example database, toggled by set_database/clear_database.
        self.database = None

    strategies = Bundle("strategy")
    strategy_tuples = Bundle("tuples")
    objects = Bundle("objects")
    # NOTE(review): no rule in this class reads or fills ``basic_data``.
    basic_data = Bundle("basic")
    varied_floats = Bundle("varied_floats")

    def teardown(self):
        self.clear_database()

    @rule()
    def clear_database(self):
        """Drop the current example database, if any."""
        if self.database is not None:
            self.database = None

    @rule()
    def set_database(self):
        """Replace any existing database with a fresh ExampleDatabase."""
        self.teardown()
        self.database = ExampleDatabase()

    @rule(
        target=strategies,
        spec=sampled_from((
            integers(),
            booleans(),
            floats(),
            complex_numbers(),
            fractions(),
            decimals(),
            text(),
            binary(),
            none(),
            tuples(),
        )),
    )
    def strategy(self, spec):
        """Add one of a fixed set of primitive strategies."""
        return spec

    @rule(target=strategies, values=lists(integers() | text(), min_size=1))
    def sampled_from_strategy(self, values):
        """Add a ``sampled_from`` over a non-empty list of ints/strings."""
        return sampled_from(values)

    @rule(target=strategies, spec=strategy_tuples)
    def strategy_for_tupes(self, spec):
        """Combine a tuple of strategies into a ``tuples()`` strategy.

        (The misspelled method name is kept: it is the rule's public name.)
        """
        return tuples(*spec)

    @rule(target=strategies,
          source=strategies,
          level=integers(1, 10),
          mixer=text())
    def filtered_strategy(self, source, level, mixer):
        """Filter ``source`` with a deterministic pseudo-random predicate.

        The predicate seeds a ``Random`` from ``mixer`` plus the value's repr
        and rejects the value when ``randint(0, level)`` is 0, i.e. roughly
        one value in ``level + 1``.
        """
        def is_good(x):
            seed = hashlib.sha384((mixer + repr(x)).encode()).digest()
            return bool(Random(seed).randint(0, level))

        return source.filter(is_good)

    @rule(target=strategies, elements=strategies)
    def list_strategy(self, elements):
        """Add ``lists(elements)``."""
        return lists(elements)

    @rule(target=strategies, left=strategies, right=strategies)
    def or_strategy(self, left, right):
        """Add the union ``left | right``."""
        return left | right

    @rule(target=varied_floats, source=floats())
    def float(self, source):
        """Seed the ``varied_floats`` bundle with an arbitrary float."""
        return source

    @rule(target=varied_floats,
          source=varied_floats,
          offset=integers(-100, 100))
    def adjust_float(self, source, offset):
        """Shift a float by ``offset`` in its integer encoding.

        NOTE(review): presumably ``float_to_int``/``int_to_float`` map a float
        to and from its 64-bit pattern, making this a step of ``offset`` ULPs.
        """
        return int_to_float(clamp(0, float_to_int(source) + offset, 2**64 - 1))

    @rule(target=strategies, left=varied_floats, right=varied_floats)
    def float_range(self, left, right):
        """Add ``floats(left, right)`` for two finite floats from the bundle."""
        assume(math.isfinite(left) and math.isfinite(right))
        left, right = sorted((left, right))
        assert left <= right
        # exclude deprecated case where left = 0.0 and right = -0.0
        assume(left or right
               or not (is_negative(right) and not is_negative(left)))
        return floats(left, right)

    @rule(
        target=strategies,
        source=strategies,
        result1=strategies,
        result2=strategies,
        mixer=text(),
        p=floats(0, 1),
    )
    def flatmapped_strategy(self, source, result1, result2, mixer, p):
        """Flatmap ``source`` onto ``result1`` or ``result2``.

        The choice is a deterministic pseudo-random function of the drawn
        value (seeded from ``mixer`` and its repr), picking ``result1`` with
        probability about ``p``.
        """
        assume(result1 is not result2)

        def do_map(value):
            rep = repr(value)
            random = Random(hashlib.sha384((mixer + rep).encode()).digest())
            if random.random() <= p:
                return result1
            else:
                return result2

        return source.flatmap(do_map)

    @rule(target=strategies, value=objects)
    def just_strategy(self, value):
        """Add ``just(value)`` for a previously drawn example."""
        return just(value)

    @rule(target=strategy_tuples, source=strategies)
    def single_tuple(self, source):
        """Start a strategy tuple with a single element."""
        return (source, )

    @rule(target=strategy_tuples, left=strategy_tuples, right=strategy_tuples)
    def cat_tuples(self, left, right):
        """Concatenate two strategy tuples."""
        return left + right

    @rule(target=objects, strat=strategies, data=data())
    def get_example(self, strat, data):
        """Draw an example from ``strat`` and store it in ``objects``.

        The drawn value must be returned: this rule targets the ``objects``
        bundle, and returning nothing would only ever add ``None`` to it.
        """
        return data.draw(strat)

    @rule(target=strategies, left=integers(), right=integers())
    def integer_range(self, left, right):
        """Add ``integers(lo, hi)`` with the bounds put in order."""
        left, right = sorted((left, right))
        return integers(left, right)

    @rule(strat=strategies)
    def repr_is_good(self, strat):
        """A strategy's repr must not fall back to the default ``at 0x...``."""
        assert " at 0x" not in repr(strat)
    class FolderOperationsStateMachine(RuleBasedStateMachine):
        Files = Bundle("file")
        Folders = Bundle("folder")
        # Moving mountpoint
        NonRootFolder = Folders.filter(lambda x: not x.is_workspace())

        @initialize(target=Folders)
        def init(self):
            nonlocal tentative
            tentative += 1
            caplog.clear()

            async def _bootstrap(user_fs, mountpoint_manager):
                wid = await user_fs.workspace_create("w")
                await mountpoint_manager.mount_workspace(wid)

            self.mountpoint_service = mountpoint_service_factory(_bootstrap)

            self.folder_oracle = Path(tmpdir / f"oracle-test-{tentative}")
            self.folder_oracle.mkdir()
            oracle_root = self.folder_oracle / "root"
            oracle_root.mkdir()
            self.folder_oracle.chmod(
                0o500)  # Root oracle can no longer be removed this way
            (oracle_root / "w").mkdir()
            oracle_root.chmod(0o500)  # Also protect workspace from deletion

            return PathElement(f"/w", self.mountpoint_service.base_mountpoint,
                               oracle_root)

        def teardown(self):
            if hasattr(self, "mountpoint_service"):
                self.mountpoint_service.stop()

        @rule(target=Files, parent=Folders, name=st_entry_name)
        def touch(self, parent, name):
            path = parent / name

            expected_exc = None
            try:
                path.to_oracle().touch(exist_ok=False)
            except OSError as exc:
                expected_exc = exc

            with expect_raises(expected_exc):
                path.to_parsec().touch(exist_ok=False)

            return path

        @rule(target=Folders, parent=Folders, name=st_entry_name)
        def mkdir(self, parent, name):
            path = parent / name

            expected_exc = None
            try:
                path.to_oracle().mkdir(exist_ok=False)
            except OSError as exc:
                expected_exc = exc

            with expect_raises(expected_exc):
                path.to_parsec().mkdir(exist_ok=False)

            return path

        @rule(path=Files)
        def unlink(self, path):
            expected_exc = None
            try:
                path.to_oracle().unlink()
            except OSError as exc:
                expected_exc = exc

            with expect_raises(expected_exc):
                path.to_parsec().unlink()

        @rule(path=Files, length=st.integers(min_value=0, max_value=16))
        def resize(self, path, length):
            expected_exc = None
            try:
                os.truncate(path.to_oracle(), length)
            except OSError as exc:
                expected_exc = exc

            with expect_raises(expected_exc):
                os.truncate(path.to_parsec(), length)

        @rule(path=NonRootFolder)
        def rmdir(self, path):
            expected_exc = None
            try:
                path.to_oracle().rmdir()
            except OSError as exc:
                expected_exc = exc

            with expect_raises(expected_exc):
                path.to_parsec().rmdir()

        def _move(self, src, dst_parent, dst_name):
            dst = dst_parent / dst_name

            expected_exc = None
            try:
                oracle_rename(src.to_oracle(), dst.to_oracle())
            except OSError as exc:
                expected_exc = exc

            with expect_raises(expected_exc):
                src.to_parsec().rename(str(dst.to_parsec()))

            return dst

        @rule(target=Files,
              src=Files,
              dst_parent=Folders,
              dst_name=st_entry_name)
        def move_file(self, src, dst_parent, dst_name):
            return self._move(src, dst_parent, dst_name)

        @rule(target=Folders,
              src=NonRootFolder,
              dst_parent=Folders,
              dst_name=st_entry_name)
        def move_folder(self, src, dst_parent, dst_name):
            return self._move(src, dst_parent, dst_name)

        @rule(path=Folders)
        def iterdir(self, path):
            expected_exc = None
            try:
                expected_children = {
                    x.name
                    for x in path.to_oracle().iterdir()
                }
            except OSError as exc:
                expected_exc = exc

            with expect_raises(expected_exc):
                children = {x.name for x in path.to_parsec().iterdir()}

            if not expected_exc:
                assert children == expected_children
Esempio n. 21
0
    return Counter([type(event)
                    for event in events]) == Counter(expected_types)


def transferred_amount(state):
    """Return the transferred amount recorded in ``state.balance_proof``.

    A missing (falsy) balance proof counts as nothing transferred, i.e. 0.
    """
    proof = state.balance_proof
    if not proof:
        return 0
    return proof.transferred_amount


# Using hypothesis.stateful.multiple() currently breaks the generation of the
# failing-example code; this helper is a temporary workaround for that.
def unwrap_multiple(multiple_results):
    """Unpack the ``values`` of a ``multiple(...)`` result.

    A single value is returned on its own; any other number of values is
    returned as the original ``values`` sequence.
    """
    values = multiple_results.values
    if len(values) == 1:
        return values[0]
    return values


# Shared bundle of partner addresses, used by ChainStateStateMachine and all
# of its mixin classes.
partners = Bundle("partners")


class ChainStateStateMachine(RuleBasedStateMachine):
    """Base state machine holding one node's view of its channels.

    This base only initialises the bookkeeping used by rules defined in
    subclasses and mixins (which share the module-level ``partners`` bundle).
    """

    def __init__(self, address=None):
        # RuleBasedStateMachine.__init__ sets up hypothesis' per-instance
        # machinery (e.g. bundle storage); skipping it leaves the machine
        # unable to run its rules.
        super().__init__()
        # Our node's address: the caller's, or one from factories.make_address().
        self.address = address or factories.make_address()
        self.replay_path = False
        self.address_to_channel = {}
        self.address_to_privkey = {}
        self.initial_number_of_channels = 1

        # Per-partner snapshots of previously seen values, defaulting to 0.
        self.our_previous_deposit = defaultdict(int)
        self.partner_previous_deposit = defaultdict(int)
        self.our_previous_transferred = defaultdict(int)
        self.partner_previous_transferred = defaultdict(int)
Esempio n. 22
0

class IntAdder(RuleBasedStateMachine):
    """Machine whose rules are attached dynamically via ``define_rule`` below."""


# Seed the "ints" bundle with a fresh integer.
IntAdder.define_rule(
    targets=("ints",),
    function=lambda self, x: x,
    arguments={"x": integers()},
)

# Re-add a fresh integer while also drawing an existing one from "ints".
IntAdder.define_rule(
    targets=("ints",),
    function=lambda self, x, y: x,
    arguments={"x": integers(), "y": Bundle("ints")},
)

TestDynamicMachine = DynamicMachine.TestCase
TestIntAdder = IntAdder.TestCase
TestPrecondition = PreconditionMachine.TestCase

# Cap every generated TestCase at 10 examples to keep the suite fast.
for test_case in [TestDynamicMachine, TestIntAdder, TestPrecondition]:
    test_case.settings = Settings(test_case.settings, max_examples=10)


def test_picks_up_settings_at_first_use_of_testcase():
    """The module-level ``max_examples=10`` override must stick to the TestCase."""
    configured = TestDynamicMachine.settings
    assert configured.max_examples == 10

Esempio n. 23
0
class MediatorMixin:
    """Mixin adding mediated-transfer rules to a chain-state state machine.

    The host class must provide ``address_to_channel``, ``address_to_privkey``,
    ``block_number``, ``chain_state``, ``token_id``, ``address``,
    ``channel_opened`` and ``event``; all of them are read below.
    """

    def __init__(self):
        super().__init__()
        # Expected balance-proof data per partner address, created lazily.
        self.partner_to_balance_proof_data = {}
        # secrethash -> secret for every transfer created by this mixin.
        self.secrethash_to_secret = {}
        # secret -> recipient for successful reveals whose unlock is pending.
        self.waiting_for_unlock = {}
        # Presumably one channel towards the initiator and one towards the target.
        self.initial_number_of_channels = 2

    def _get_balance_proof_data(self, partner):
        """Return (creating on first use) the expected balance proof for ``partner``."""
        if partner not in self.partner_to_balance_proof_data:
            partner_channel = self.address_to_channel[partner]
            self.partner_to_balance_proof_data[partner] = BalanceProofData(
                canonical_identifier=partner_channel.canonical_identifier)
        return self.partner_to_balance_proof_data[partner]

    def _update_balance_proof_data(self, partner, amount, expiration, secret):
        """Record a new lock of ``amount`` for ``partner`` and return the updated data."""
        expected = self._get_balance_proof_data(partner)
        lock = HashTimeLockState(amount=amount,
                                 expiration=expiration,
                                 secrethash=sha256_secrethash(secret))
        expected.update(amount, lock)
        return expected

    init_mediators = Bundle("init_mediators")
    secret_requests = Bundle("secret_requests")
    unlocks = Bundle("unlocks")

    def _new_mediator_transfer(self, initiator_address, target_address,
                               payment_id, amount,
                               secret) -> LockedTransferSignedState:
        """Build a signed locked transfer from the initiator to this node.

        The lock expires 10 blocks after the current block, and the secret is
        remembered so later rules can reveal it.
        """
        initiator_pkey = self.address_to_privkey[initiator_address]
        expiration = self.block_number + 10
        balance_proof_data = self._update_balance_proof_data(
            initiator_address, amount, expiration, secret)
        self.secrethash_to_secret[sha256_secrethash(secret)] = secret

        return factories.create(
            factories.LockedTransferSignedStateProperties(
                **balance_proof_data.properties.__dict__,
                amount=amount,
                expiration=expiration,
                payment_identifier=payment_id,
                secret=secret,
                initiator=initiator_address,
                target=target_address,
                token=self.token_id,
                sender=initiator_address,
                recipient=self.address,
                pkey=initiator_pkey,
                message_identifier=1,
            ))

    def _action_init_mediator(
            self, transfer: LockedTransferSignedState) -> ActionInitMediator:
        """Wrap ``transfer`` in an ActionInitMediator routed to its target."""
        initiator_channel = self.address_to_channel[transfer.initiator]
        target_channel = self.address_to_channel[transfer.target]

        return ActionInitMediator(
            route_states=[factories.make_route_from_channel(target_channel)],
            from_hop=factories.make_hop_to_channel(initiator_channel),
            from_transfer=transfer,
            balance_proof=transfer.balance_proof,
            sender=transfer.balance_proof.sender,
        )

    @rule(
        target=init_mediators,
        initiator_address=partners,
        target_address=partners,
        payment_id=payment_id(),  # pylint: disable=no-value-for-parameter
        amount=integers(min_value=1, max_value=100),
        secret=secret(),  # pylint: disable=no-value-for-parameter
    )
    def valid_init_mediator(self, initiator_address, target_address,
                            payment_id, amount, secret):
        """Start mediating a transfer between two distinct partners."""
        assume(initiator_address != target_address)

        transfer = self._new_mediator_transfer(initiator_address,
                                               target_address, payment_id,
                                               amount, secret)
        action = self._action_init_mediator(transfer)
        result = node.state_transition(self.chain_state, action)

        assert event_types_match(result.events, SendProcessed,
                                 SendLockedTransfer)

        return action

    @rule(target=secret_requests, previous_action=consumes(init_mediators))
    def valid_receive_secret_reveal(self, previous_action):
        """Reveal the secret from the target and check the resulting events.

        The expected events depend on whether the reveal arrives before the
        confirmation margin, before lock removal, or after it.
        """
        secret = self.secrethash_to_secret[
            previous_action.from_transfer.lock.secrethash]
        sender = previous_action.from_transfer.target
        recipient = previous_action.from_transfer.initiator

        action = ReceiveSecretReveal(secret=secret, sender=sender)
        result = node.state_transition(self.chain_state, action)

        expiration = previous_action.from_transfer.lock.expiration
        in_time = self.block_number < expiration - DEFAULT_NUMBER_OF_BLOCK_CONFIRMATIONS
        still_waiting = self.block_number < expiration + DEFAULT_WAIT_BEFORE_LOCK_REMOVAL

        if in_time and self.channel_opened(sender) and self.channel_opened(
                recipient):
            assert event_types_match(result.events, SendSecretReveal,
                                     SendBalanceProof, EventUnlockSuccess)
            self.event("Unlock successful.")
            self.waiting_for_unlock[secret] = recipient
        elif still_waiting and self.channel_opened(recipient):
            assert event_types_match(result.events, SendSecretReveal)
            self.event("Unlock failed, secret revealed too late.")
        else:
            assert not result.events
            self.event(
                "ReceiveSecretRevealed after removal of lock - dropped.")
        return action

    @rule(previous_action=secret_requests)
    def replay_receive_secret_reveal(self, previous_action):
        """Replaying an already-processed reveal must emit no events."""
        result = node.state_transition(self.chain_state, previous_action)
        assert not result.events

    # pylint: disable=no-value-for-parameter
    @rule(previous_action=secret_requests, invalid_sender=address())
    # pylint: enable=no-value-for-parameter
    def replay_receive_secret_reveal_scrambled_sender(self, previous_action,
                                                      invalid_sender):
        """Replaying a reveal from an arbitrary sender must emit no events."""
        action = ReceiveSecretReveal(previous_action.secret, invalid_sender)
        result = node.state_transition(self.chain_state, action)
        assert not result.events

    # pylint: disable=no-value-for-parameter
    @rule(previous_action=init_mediators, secret=secret())
    # pylint: enable=no-value-for-parameter
    def wrong_secret_receive_secret_reveal(self, previous_action, secret):
        """Revealing a secret that does not match the lock must emit no events."""
        # Guard against drawing the real secret, which would be a valid reveal.
        assume(sha256_secrethash(secret) !=
               previous_action.from_transfer.lock.secrethash)
        sender = previous_action.from_transfer.target
        action = ReceiveSecretReveal(secret, sender)
        result = node.state_transition(self.chain_state, action)
        assert not result.events

    # pylint: disable=no-value-for-parameter
    @rule(target=secret_requests,
          previous_action=consumes(init_mediators),
          invalid_sender=address())
    # pylint: enable=no-value-for-parameter
    def wrong_address_receive_secret_reveal(self, previous_action,
                                            invalid_sender):
        """A reveal from the wrong sender must emit no events.

        Returns (without applying it) the correctly-addressed reveal so that
        the replay rules can exercise it.
        """
        secret = self.secrethash_to_secret[
            previous_action.from_transfer.lock.secrethash]
        valid_sender = previous_action.from_transfer.target
        # Guard against drawing the real sender, which would be a valid reveal.
        assume(invalid_sender != valid_sender)

        invalid_action = ReceiveSecretReveal(secret, invalid_sender)
        result = node.state_transition(self.chain_state, invalid_action)
        assert not result.events

        valid_action = ReceiveSecretReveal(secret, valid_sender)
        return valid_action
Esempio n. 24
0
def test_rule_deprecation_targets_and_target():
    """Call ``rule()`` with both ``targets=`` and ``target=`` supplied.

    NOTE(review): nothing in this body checks for a deprecation warning or
    error; confirm whether a checking decorator/context is missing.
    """
    first = Bundle("k")
    second = Bundle("v")
    rule(targets=(first, ), target=second)
Esempio n. 25
0
class PVectorBuilder(RuleBasedStateMachine):
    """
    Grow a plain list and a pvector side by side.

    Every rule applies one operation to both halves of a ``(list, pvector)``
    pair and yields the new pair; ``compare`` checks that the halves agree.
    Inputs are never mutated: list results are always fresh lists.
    """
    sequences = Bundle("sequences")

    @rule(target=sequences, start=PVectorAndLists)
    def initial_value(self, start):
        """
        Seed the bundle with pairs drawn from the PVectorAndLists strategy.
        """
        return start

    @rule(target=sequences, former=sequences)
    @verify_inputs_unmodified
    def append(self, former):
        """
        Add one fresh item to the end of both sequences.
        """
        plain, vec = former
        item = TestObject()
        return [*plain, item], vec.append(item)

    @rule(target=sequences, start=sequences, end=sequences)
    @verify_inputs_unmodified
    def extend(self, start, end):
        """
        Concatenate the second pair of sequences onto the first.
        """
        plain, vec = start
        other_plain, other_vec = end
        # compare() has O(N**2) behavior, so don't want too-large lists:
        assume(len(plain) + len(other_plain) < 50)
        return [*plain, *other_plain], vec.extend(other_vec)

    @rule(target=sequences, former=sequences, choice=st.choices())
    @verify_inputs_unmodified
    def remove(self, former, choice):
        """
        Drop the item at a randomly chosen index from both sequences.
        """
        plain, vec = former
        assume(plain)
        idx = choice(range(len(plain)))
        return plain[:idx] + plain[idx + 1:], vec.delete(idx)

    @rule(target=sequences, former=sequences, choice=st.choices())
    @verify_inputs_unmodified
    def set(self, former, choice):
        """
        Replace the item at a randomly chosen index with a fresh one.
        """
        plain, vec = former
        assume(plain)
        idx = choice(range(len(plain)))
        item = TestObject()
        updated = list(plain)
        updated[idx] = item
        return updated, vec.set(idx, item)

    @rule(target=sequences, former=sequences, choice=st.choices())
    @verify_inputs_unmodified
    def transform_set(self, former, choice):
        """
        Replace an item, using ``pvector.transform`` on the vector side.
        """
        plain, vec = former
        assume(plain)
        idx = choice(range(len(plain)))
        item = TestObject()
        updated = list(plain)
        updated[idx] = item
        return updated, vec.transform([idx], item)

    @rule(target=sequences, former=sequences, choice=st.choices())
    @verify_inputs_unmodified
    def transform_discard(self, former, choice):
        """
        Drop an item, using ``transform(..., discard)`` on the vector side.
        """
        plain, vec = former
        assume(plain)
        idx = choice(range(len(plain)))
        return plain[:idx] + plain[idx + 1:], vec.transform([idx], discard)

    @rule(target=sequences, former=sequences, choice=st.choices())
    @verify_inputs_unmodified
    def subset(self, former, choice):
        """
        Slice both sequences between two randomly chosen indices.
        """
        plain, vec = former
        assume(plain)
        start_idx = choice(range(len(plain)))
        stop_idx = choice(range(len(plain)))
        return plain[start_idx:stop_idx], vec[start_idx:stop_idx]

    @rule(pair=sequences)
    @verify_inputs_unmodified
    def compare(self, pair):
        """
        Check that the list and the pvector hold the same items.
        """
        plain, vec = pair
        # compare() has O(N**2) behavior, so don't want too-large lists:
        assume(len(plain) < 50)
        assert_equal(plain, vec)
Esempio n. 26
0
def test_rule_deprecation_bundle_by_name():
    """Call ``rule()`` naming its target bundle by string instead of object.

    NOTE(review): nothing in this body checks for a deprecation warning or
    error; confirm whether a checking decorator/context is missing.
    """
    bundle_name = "k"
    Bundle(bundle_name)
    rule(target=bundle_name)
Esempio n. 27
0
 class NonTerminalMachine(RuleBasedStateMachine):
     """Machine whose only rule draws from bundle "hi", which no rule here fills."""

     @rule(value=Bundle("hi"))
     def bye(self, hi):
         """Intentionally empty rule body.

         NOTE(review): the rule argument is named ``value`` but the parameter
         is ``hi``; this only stays harmless while the rule never fires.
         """
Esempio n. 28
0
def test_rule_non_bundle_target_oneof():
    """Passing ``k | v`` as ``target=`` must raise InvalidArgument with a hint.

    The expected hint contains literal parentheses and a ``|``, which are
    regex metacharacters, so they are escaped in ``pattern``. Unescaped,
    ``(a, b)`` becomes a capture group and ``|`` splits the whole pattern
    into two alternatives, so almost any message containing " b` " followed
    by more text would satisfy the match.
    """
    k, v = Bundle("k"), Bundle("v")
    pattern = r".+ `one_of\(a, b\)` or `a \| b` .+"
    with pytest.raises(InvalidArgument, match=pattern):
        rule(target=k | v)
Esempio n. 29
0
# Attach a rule with no targets and no arguments that just returns 1.
DynamicMachine.define_rule(
    targets=(),
    function=lambda self: 1,
    arguments={},
)


class IntAdder(RuleBasedStateMachine):
    """Machine whose rules are attached dynamically via ``define_rule`` below."""


# Seed the "ints" bundle with a fresh integer.
IntAdder.define_rule(
    targets=("ints",),
    function=lambda self, x: x,
    arguments={"x": integers()},
)

# Re-add a fresh integer while also drawing an existing one from "ints".
IntAdder.define_rule(
    targets=("ints",),
    function=lambda self, x, y: x,
    arguments={"x": integers(), "y": Bundle("ints")},
)


class ChoosingMachine(GenericStateMachine):
    """Generic state machine whose every step is a ``choices()`` chooser."""

    def steps(self):
        strategy = choices()
        return strategy

    def execute_step(self, choices):
        # ``choices`` is the drawn chooser callable (it shadows the module-level
        # strategy of the same name); pick one of three fixed options.
        options = [1, 2, 3]
        choices(options)


# Create these TestCase classes while a max_examples=10 Settings object is
# active as a context manager.
# NOTE(review): assumes ``.TestCase`` captures the ambient settings at the
# moment it is accessed — confirm against the hypothesis version in use.
with Settings(max_examples=10):
    TestChoosingMachine = ChoosingMachine.TestCase
    TestGoodSets = GoodSet.TestCase
    TestGivenLike = GivenLikeStateMachine.TestCase
Esempio n. 30
0
class ListSM(RuleBasedStateMachine):
    """Model-based test of a C linked list accessed through cffi.

    The list under test is ``self._var`` (a ``struct list*``). Two deques
    model it: ``_model`` holds the node handles in list order and
    ``_model_contents`` the corresponding values. Invariants compare the C
    list with both models after every step.
    """

    def __init__(self):
        self._var = ffi.new('struct list*')
        self._model = deque()
        self._model_contents = deque()
        super().__init__()

    # NOTE(review): presumably forward to the same-named fields of self._var.
    first = alias_property()
    last = alias_property()

    def teardown(self):
        # Hand the C list back to the library; presumably frees remaining nodes.
        lib.list_clear(self._var)

    class Iterator(typing.Iterator[None]):
        """Base type of list iterators; ``cur``/``prev`` expose the C cursor."""
        cur = alias_property()
        prev = alias_property()

    # The ``ListSM.Iterator`` return annotations are quoted: without
    # ``from __future__ import annotations`` they are evaluated while this
    # class body executes, before the name ``ListSM`` exists (NameError).
    def _make_iter(self, reverse: bool = False) -> "ListSM.Iterator":
        """Return a cursor-style iterator over the C list.

        Each ``next()`` yields ``None``; callers read the current node from
        the iterator's ``cur`` attribute. With ``reverse`` the cursor starts
        at ``last`` instead of ``first``; advancing is delegated to
        ``lib.list_step``.
        """
        var = ffi.new('struct list_it*')
        var.cur = self.last if reverse else self.first

        def _steps():
            while var.cur != ffi.NULL:
                yield  # yield None because we're mutable
                lib.list_step(var)

        steps = _steps()

        class _It(ListSM.Iterator):
            _var = var

            def __next__(self) -> None:
                return next(steps)

        return _It()

    def __iter__(self) -> "ListSM.Iterator":
        return self._make_iter()

    def __reversed__(self) -> "ListSM.Iterator":
        return self._make_iter(reverse=True)

    nodes = Bundle('Nodes')

    @rule(new_value=elements(), target=nodes)
    def insert_front(self, new_value):
        """Insert ``new_value`` at the head and record the new node."""
        self._model_contents.appendleft(new_value)
        lib.list_insert_front(self._var, new_value)
        new_node = self.first
        assert new_node.value == new_value
        self._model.appendleft(new_node)
        return new_node

    @rule(new_value=elements(), target=nodes)
    def insert_back(self, new_value):
        """Insert ``new_value`` at the tail and record the new node."""
        self._model_contents.append(new_value)
        lib.list_insert_back(self._var, new_value)
        new_node = self.last
        assert new_node.value == new_value
        self._model.append(new_node)
        return new_node

    @rule(nodes=strategies.frozensets(consumes(nodes)), reverse=strategies.booleans())
    def remove_thru_iter(self, nodes, reverse):
        """Remove a set of consumed nodes while iterating the list.

        Walks the list (forwards or backwards), unlinks each node found in
        ``nodes`` through the iterator, then drops the same nodes from both
        models.
        """
        it = reversed(self) if reverse else iter(self)
        for _ in it:
            if it.cur in nodes:
                lib.list_remove(self._var, it._var)
        for n in nodes:
            i = self._model.index(n)
            del self._model_contents[i]
            del self._model[i]

    @invariant()
    def nodes_as_model(self):
        """The C list's nodes, in order, match ``_model``."""
        it = iter(self)
        nodes = [it.cur for _ in it]
        assert nodes == list(self._model)

    @invariant()
    def contents_as_model(self):
        """The C list's values, in order, match ``_model_contents``."""
        it = iter(self)
        contents = [it.cur.value for _ in it]
        assert contents == list(self._model_contents)