class QCircuitMachine(RuleBasedStateMachine):
    """Build a Hypothesis rule based state machine for constructing, transpiling
    and simulating a series of random QuantumCircuits.

    Build circuits with up to QISKIT_RANDOM_QUBITS qubits, apply a random
    selection of gates from qiskit.circuit.library with randomly selected
    qargs, cargs, and parameters. At random intervals, transpile the circuit for
    a random backend with a random optimization level and simulate both the
    initial and the transpiled circuits to verify that their counts are the
    same.
    """

    # Bundles hold previously-created bits so later rules can draw from them.
    qubits = Bundle("qubits")
    clbits = Bundle("clbits")

    backend = Aer.get_backend("aer_simulator")
    # Use only half of the simulator's qubits so transpilation to a random
    # (smaller) backend still has room to succeed.
    max_qubits = int(backend.configuration().n_qubits / 2)

    # Limit reg generation for more interesting circuits
    max_qregs = 3
    max_cregs = 3

    def __init__(self):
        super().__init__()
        # The circuit under construction for this machine run.
        self.qc = QuantumCircuit()
        # Only enable the variadic-gate rule when such gates were discovered.
        self.enable_variadic = bool(variadic_gates)

    @precondition(lambda self: len(self.qc.qubits) < self.max_qubits)
    @precondition(lambda self: len(self.qc.qregs) < self.max_qregs)
    @rule(target=qubits, n=st.integers(min_value=1, max_value=max_qubits))
    def add_qreg(self, n):
        """Adds a new variable sized qreg to the circuit, up to max_qubits."""
        # Clamp so the total qubit count never exceeds max_qubits.
        n = min(n, self.max_qubits - len(self.qc.qubits))
        qreg = QuantumRegister(n)
        self.qc.add_register(qreg)
        # Feed the individual qubits (not the register) into the bundle.
        return multiple(*list(qreg))

    @precondition(lambda self: len(self.qc.cregs) < self.max_cregs)
    @rule(target=clbits, n=st.integers(1, 5))
    def add_creg(self, n):
        """Add a new variable sized creg to the circuit."""
        creg = ClassicalRegister(n)
        self.qc.add_register(creg)
        return multiple(*list(creg))

    # Gates of various shapes

    # NOTE(review): this precondition requires clbits to exist even for
    # purely quantum gates — presumably to keep shrinking simple; confirm
    # that is intentional.
    @precondition(lambda self: self.qc.num_qubits > 0 and self.qc.num_clbits > 0)
    @rule(n_arguments=st.sampled_from(sorted(BASE_INSTRUCTIONS.keys())), data=st.data())
    def add_gate(self, n_arguments, data):
        """Append a random fixed gate to the circuit."""
        n_qubits, n_clbits, n_params = n_arguments
        gate_class = data.draw(st.sampled_from(BASE_INSTRUCTIONS[n_qubits, n_clbits, n_params]))
        qubits = data.draw(st.lists(self.qubits, min_size=n_qubits, max_size=n_qubits, unique=True))
        clbits = data.draw(st.lists(self.clbits, min_size=n_clbits, max_size=n_clbits, unique=True))
        # Angle-like parameters bounded to +/- 10*pi to keep values sane.
        params = data.draw(
            st.lists(
                st.floats(
                    allow_nan=False, allow_infinity=False, min_value=-10 * pi, max_value=10 * pi
                ),
                min_size=n_params,
                max_size=n_params,
            )
        )
        self.qc.append(gate_class(*params), qubits, clbits)

    @precondition(lambda self: self.enable_variadic)
    @rule(gate=st.sampled_from(variadic_gates), qargs=st.lists(qubits, min_size=1, unique=True))
    def add_variQ_gate(self, gate, qargs):
        """Append a gate with a variable number of qargs."""
        self.qc.append(gate(len(qargs)), qargs)

    @precondition(lambda self: len(self.qc.data) > 0)
    @rule(carg=clbits, data=st.data())
    def add_c_if_last_gate(self, carg, data):
        """Modify the last gate to be conditional on a classical register."""
        # Resolve the register that owns the drawn clbit.
        creg = self.qc.find_bit(carg).registers[0][0]
        val = data.draw(st.integers(min_value=0, max_value=2 ** len(creg) - 1))

        # NOTE(review): circuit data is indexed as a legacy
        # (operation, qargs, cargs) tuple here — this ties the test to older
        # Qiskit versions; newer code would use `.operation`. Confirm against
        # the pinned Qiskit version.
        last_gate = self.qc.data[-1]

        # Conditional instructions are not supported
        assume(isinstance(last_gate[0], Gate))

        last_gate[0].c_if(creg, val)

    # Properties to check

    @invariant()
    def qasm(self):
        """After each circuit operation, it should be possible to build QASM."""
        self.qc.qasm()

    # Only meaningful once the circuit measures something.
    @precondition(lambda self: any(isinstance(d[0], Measure) for d in self.qc.data))
    @rule(conf=transpiler_conf())
    def equivalent_transpile(self, conf):
        """Simulate, transpile and simulate the present circuit. Verify that the
        counts are not significantly different before and after transpilation.
        """
        backend, opt_level, layout_method, routing_method, scheduling_method = conf

        # Skip configurations where the target backend is too small.
        assume(backend is None or backend.configuration().n_qubits >= len(self.qc.qubits))

        print(
            f"Evaluating circuit at level {opt_level} on {backend} "
            f"using layout_method={layout_method} routing_method={routing_method} "
            f"and scheduling_method={scheduling_method}:\n{self.qc.qasm()}"
        )

        shots = 4096

        # Note that there's no transpilation here, which is why the gates are limited to only ones
        # that Aer supports natively.
        aer_counts = self.backend.run(self.qc, shots=shots).result().get_counts()

        try:
            xpiled_qc = transpile(
                self.qc,
                backend=backend,
                optimization_level=opt_level,
                layout_method=layout_method,
                routing_method=routing_method,
                scheduling_method=scheduling_method,
            )
        except Exception as e:
            # Re-raise with the failing circuit's QASM for reproducibility.
            failed_qasm = "Exception caught during transpilation of circuit: \n{}".format(
                self.qc.qasm()
            )
            raise RuntimeError(failed_qasm) from e

        xpiled_aer_counts = self.backend.run(xpiled_qc, shots=shots).result().get_counts()

        # Allow 5% of shots worth of statistical drift between the two runs.
        count_differences = dicts_almost_equal(aer_counts, xpiled_aer_counts, 0.05 * shots)

        assert (
            count_differences == ""
        ), "Counts not equivalent: {}\nFailing QASM Input:\n{}\n\nFailing QASM Output:\n{}".format(
            count_differences, self.qc.qasm(), xpiled_qc.qasm()
        )
class verifyingstatemachine(RuleBasedStateMachine):
    """This defines the set of acceptable operations on a Mercurial repository
    using Hypothesis's RuleBasedStateMachine.

    The general concept is that we manage multiple repositories inside a
    repos/ directory in our temporary test location. Some of these are
    freshly inited, some are clones of the others. Our current working
    directory is always inside one of these repositories while the tests
    are running.

    Hypothesis then performs a series of operations against these
    repositories, including hg commands, generating contents and editing
    the .hgrc file. If these operations fail in unexpected ways or behave
    differently in different configurations of Mercurial, the test will
    fail and a minimized .t test file will be written to the
    hypothesis-generated directory to exhibit that failure.

    Operations are defined as methods with @rule() decorators. See the
    Hypothesis documentation at
    http://hypothesis.readthedocs.org/en/release/stateful.html for more
    details."""

    # A bundle is a reusable collection of previously generated data which may
    # be provided as arguments to future operations.
    repos = Bundle('repos')
    paths = Bundle('paths')
    contents = Bundle('contents')
    branches = Bundle('branches')
    committimes = Bundle('committimes')

    def __init__(self):
        super(verifyingstatemachine, self).__init__()
        self.repodir = os.path.join(testtmp, "repos")
        if os.path.exists(self.repodir):
            shutil.rmtree(self.repodir)
        os.chdir(testtmp)
        # NOTE(review): `self.log` (the shell transcript) shadows the `log`
        # rule method on instances; Hypothesis invokes rules via the class,
        # so this works, but it is fragile — confirm intentional.
        self.log = []
        self.failed = False
        # Per-repository --config overrides, keyed by repo directory name.
        self.configperrepo = {}
        self.all_extensions = set()
        self.non_skippable_extensions = set()

        # Start every run from a fresh repos/repo1 working directory.
        self.mkdirp("repos")
        self.cd("repos")
        self.mkdirp("repo1")
        self.cd("repo1")
        self.hg("init")

    def teardown(self):
        """On teardown we clean up after ourselves as usual, but we also
        do some additional testing: We generate a .t file based on our test
        run using run-test.py -i to get the correct output.

        We then test it in a number of other configurations, verifying that
        each passes the same test."""
        super(verifyingstatemachine, self).teardown()
        try:
            shutil.rmtree(self.repodir)
        except OSError:
            pass
        # Convert the recorded shell transcript into .t-test format.
        ttest = os.linesep.join(" " + l for l in self.log)
        os.chdir(testtmp)
        path = os.path.join(testtmp, "test-generated.t")
        with open(path, 'w') as o:
            o.write(ttest + os.linesep)
        # Run the generated test interactively (-i) so run-tests fills in
        # the expected output; "yes" accepts every proposed change.
        with open(os.devnull, "w") as devnull:
            rewriter = subprocess.Popen(
                [runtests, "--local", "-i", path],
                stdin=subprocess.PIPE,
                stdout=devnull,
                stderr=devnull,
            )
            rewriter.communicate("yes")
            with open(path, 'r') as i:
                ttest = i.read()

        # NOTE(review): under Python 3 the `e` bound by `except ... as e`
        # below is deleted when the except block exits, so the later
        # `e is not None` checks would raise NameError — this code is
        # Python-2 only as written; confirm before porting.
        e = None
        if not self.failed:
            try:
                # Re-run the now-completed test under the pure-Python
                # implementation and assert it still passes.
                output = subprocess.check_output(
                    [runtests, path, "--local", "--pure"],
                    stderr=subprocess.STDOUT)
                assert "Ran 1 test" in output, output
                # For every extension that was enabled but never actually
                # exercised, verify the test also passes with it removed.
                for ext in (self.all_extensions - self.non_skippable_extensions):
                    tf = os.path.join(testtmp, "test-generated-no-%s.t" % (ext, ))
                    with open(tf, 'w') as o:
                        for l in ttest.splitlines():
                            if l.startswith(" $ hg"):
                                l = l.replace(
                                    "--config %s=" % (extensionconfigkey(ext), ), "")
                            o.write(l + os.linesep)
                    with open(tf, 'r') as r:
                        t = r.read()
                        assert ext not in t, t
                    output = subprocess.check_output([
                        runtests, tf, "--local",
                    ], stderr=subprocess.STDOUT)
                    assert "Ran 1 test" in output, output
            except subprocess.CalledProcessError as e:
                note(e.output)
        if self.failed or e is not None:
            # Preserve the failing .t file for post-mortem inspection.
            with open(savefile, "wb") as o:
                o.write(ttest)
        if e is not None:
            raise e

    def execute_step(self, step):
        # Track whether any step failed so teardown can save the transcript;
        # Hypothesis-internal errors and Ctrl-C pass through untouched.
        try:
            return super(verifyingstatemachine, self).execute_step(step)
        except (HypothesisException, KeyboardInterrupt):
            raise
        except Exception:
            self.failed = True
            raise

    # Section: Basic commands.
    def mkdirp(self, path):
        # No-op if the directory already exists; otherwise create it and
        # record the equivalent shell command in the transcript.
        if os.path.exists(path):
            return
        self.log.append("$ mkdir -p -- %s" % (pipes.quote(os.path.relpath(path)), ))
        os.makedirs(path)

    def cd(self, path):
        path = os.path.relpath(path)
        if path == ".":
            return
        os.chdir(path)
        self.log.append("$ cd -- %s" % (pipes.quote(path), ))

    def hg(self, *args):
        # Run hg with this repository's accumulated --config overrides.
        extra_flags = []
        for key, value in self.config.items():
            extra_flags.append("--config")
            extra_flags.append("%s=%s" % (key, value))
        self.command("hg", *(tuple(extra_flags) + args))

    def command(self, *args):
        # Record the command in the transcript, then execute it; a non-zero
        # exit raises CalledProcessError (rules wrap this in
        # acceptableerrors where failures are expected).
        self.log.append("$ " + ' '.join(map(pipes.quote, args)))
        subprocess.check_output(args, stderr=subprocess.STDOUT)

    # Section: Set up basic data
    # This section has no side effects but generates data that we will want
    # to use later.
    @rule(target=paths, source=st.lists(files, min_size=1).map(lambda l: os.path.join(*l)))
    def genpath(self, source):
        return source

    @rule(target=committimes, when=datetimes(min_year=1970, max_year=2038) | st.none())
    def gentime(self, when):
        return when

    @rule(target=contents, content=st.one_of(st.binary(), st.text().map(lambda x: x.encode('utf-8'))))
    def gencontent(self, content):
        return content

    @rule(
        target=branches,
        name=safetext,
    )
    def genbranch(self, name):
        return name

    @rule(target=paths, source=paths)
    def lowerpath(self, source):
        # Case-twiddled variants probe case-sensitivity handling.
        return source.lower()

    @rule(target=paths, source=paths)
    def upperpath(self, source):
        return source.upper()

    # Section: Basic path operations
    @rule(path=paths, content=contents)
    def writecontent(self, path, content):
        self.unadded_changes = True
        if os.path.isdir(path):
            return
        parent = os.path.dirname(path)
        if parent:
            try:
                self.mkdirp(parent)
            except OSError:
                # It may be the case that there is a regular file that has
                # previously been created that has the same name as an ancestor
                # of the current path. This will cause mkdirp to fail with this
                # error. We just turn this into a no-op in that case.
                return
        with open(path, 'wb') as o:
            o.write(content)
        self.log.append(("$ python -c 'import binascii; "
                         "print(binascii.unhexlify(\"%s\"))' > %s") % (
            binascii.hexlify(content),
            pipes.quote(path),
        ))

    @rule(path=paths)
    def addpath(self, path):
        if os.path.exists(path):
            self.hg("add", "--", path)

    @rule(path=paths)
    def forgetpath(self, path):
        if os.path.exists(path):
            with acceptableerrors(
                "file is already untracked",
            ):
                self.hg("forget", "--", path)

    @rule(s=st.none() | st.integers(0, 100))
    def addremove(self, s):
        args = ["addremove"]
        if s is not None:
            # -s sets the similarity threshold for rename detection.
            args.extend(["-s", str(s)])
        self.hg(*args)

    @rule(path=paths)
    def removepath(self, path):
        if os.path.exists(path):
            with acceptableerrors(
                'file is untracked',
                'file has been marked for add',
                'file is modified',
            ):
                self.hg("remove", "--", path)

    @rule(
        message=safetext,
        amend=st.booleans(),
        when=committimes,
        addremove=st.booleans(),
        secret=st.booleans(),
        close_branch=st.booleans(),
    )
    def maybecommit(self, message, amend, when, addremove, secret, close_branch):
        # Build the commit command incrementally, tracking which error
        # messages are acceptable for the chosen flag combination.
        command = ["commit"]
        errors = ["nothing changed"]
        if amend:
            errors.append("cannot amend public changesets")
            command.append("--amend")
        command.append("-m" + pipes.quote(message))
        if secret:
            command.append("--secret")
        if close_branch:
            command.append("--close-branch")
            errors.append("can only close branch heads")
        if addremove:
            command.append("--addremove")
        if when is not None:
            # Boundary years can be unrepresentable as 32-bit timestamps.
            if when.year == 1970:
                errors.append('negative date value')
            if when.year == 2038:
                errors.append('exceeds 32 bits')
            command.append("--date=%s" % (when.strftime('%Y-%m-%d %H:%M:%S %z'), ))
        with acceptableerrors(*errors):
            self.hg(*command)

    # Section: Repository management
    @property
    def currentrepo(self):
        # Repo identity is just the name of the current working directory.
        return os.path.basename(os.getcwd())

    @property
    def config(self):
        return self.configperrepo.setdefault(self.currentrepo, {})

    @rule(
        target=repos,
        source=repos,
        name=reponames,
    )
    def clone(self, source, name):
        if not os.path.exists(os.path.join("..", name)):
            self.cd("..")
            self.hg("clone", source, name)
            self.cd(name)
        return name

    @rule(
        target=repos,
        name=reponames,
    )
    def fresh(self, name):
        if not os.path.exists(os.path.join("..", name)):
            self.cd("..")
            self.mkdirp(name)
            self.cd(name)
            self.hg("init")
        return name

    @rule(name=repos)
    def switch(self, name):
        self.cd(os.path.join("..", name))
        assert self.currentrepo == name
        assert os.path.exists(".hg")

    @rule(target=repos)
    def origin(self):
        return "repo1"

    # NOTE(review): `repo=repos` here is a plain default argument, not a
    # drawn rule argument (it is outside @rule(...)), so `repo` is always
    # the Bundle object and is unused — presumably a leftover; confirm.
    @rule()
    def pull(self, repo=repos):
        with acceptableerrors(
            "repository default not found",
            "repository is unrelated",
        ):
            self.hg("pull")

    @rule(newbranch=st.booleans())
    def push(self, newbranch):
        with acceptableerrors(
            "default repository not configured",
            "no changes found",
        ):
            if newbranch:
                self.hg("push", "--new-branch")
            else:
                with acceptableerrors("creates new branches"):
                    self.hg("push")

    # Section: Simple side effect free "check" operations
    @rule()
    def log(self):
        self.hg("log")

    @rule()
    def verify(self):
        self.hg("verify")

    @rule()
    def diff(self):
        self.hg("diff", "--nodates")

    @rule()
    def status(self):
        self.hg("status")

    @rule()
    def export(self):
        self.hg("export")

    # Section: Branch management
    @rule()
    def checkbranch(self):
        self.hg("branch")

    @rule(branch=branches)
    def switchbranch(self, branch):
        with acceptableerrors(
            'cannot use an integer as a name',
            'cannot be used in a name',
            'a branch of the same name already exists',
            'is reserved',
        ):
            self.hg("branch", "--", branch)

    @rule(branch=branches, clean=st.booleans())
    def update(self, branch, clean):
        with acceptableerrors(
            'unknown revision',
            'parse error',
        ):
            if clean:
                self.hg("update", "-C", "--", branch)
            else:
                self.hg("update", "--", branch)

    # Section: Extension management
    def hasextension(self, extension):
        return extensionconfigkey(extension) in self.config

    def commandused(self, extension):
        # Once an extension's command is actually used, the teardown step
        # must not try running the generated test without that extension.
        assert extension in self.all_extensions
        self.non_skippable_extensions.add(extension)

    @rule(extension=extensions)
    def addextension(self, extension):
        self.all_extensions.add(extension)
        self.config[extensionconfigkey(extension)] = ""

    @rule(extension=extensions)
    def removeextension(self, extension):
        self.config.pop(extensionconfigkey(extension), None)

    # Section: Commands from the shelve extension
    @rule()
    @precondition(lambda self: self.hasextension("shelve"))
    def shelve(self):
        self.commandused("shelve")
        with acceptableerrors("nothing changed"):
            self.hg("shelve")

    @rule()
    @precondition(lambda self: self.hasextension("shelve"))
    def unshelve(self):
        self.commandused("shelve")
        with acceptableerrors("no shelved changes to apply"):
            self.hg("unshelve")
class DynamicMachine(RuleBasedStateMachine):
    # NOTE(review): `test_stuff` has no explicit `self` parameter — the
    # instance is bound to `x`, and the drawn `value` keyword is not
    # accepted at all, so invoking this rule would raise TypeError. This
    # looks like a deliberate "bad machine" fixture (e.g. for exercising
    # Hypothesis's error reporting); if it is meant to run normally, the
    # signature should be `def test_stuff(self, value)`. Confirm intent
    # before changing.
    @rule(value=Bundle(u'hi'))
    def test_stuff(x):
        pass
class InitiatorMixin:
    """Mixin of initiator-side rules for a transfer state-machine test.

    Drives ``node.state_transition`` with ``ActionInitInitiator`` /
    ``ReceiveSecretRequest`` inputs and asserts on the resulting events.

    Relies on attributes provided by the host class it is mixed into:
    ``address_to_channel``, ``expected_expiry``, ``block_number``,
    ``chain_state``, ``token_network_registry_address``,
    ``token_network_address``, ``address``, ``channel_opened`` and
    ``event`` are all referenced but not defined here.
    """

    def __init__(self):
        super().__init__()
        # Secrets handed to any transfer so far (valid or not).
        self.used_secrets = set()
        # Secret hashes for which a SecretRequest has already been consumed.
        self.processed_secret_requests = set()
        # Secrets of transfers that were successfully initiated.
        self.initiated = set()

    def _action_init_initiator(self, transfer: TransferDescriptionWithSecretState):
        """Build an ActionInitInitiator for `transfer`, recording its expected lock expiry."""
        channel = self.address_to_channel[transfer.target]
        if transfer.secrethash not in self.expected_expiry:
            # NOTE(review): the +10 offset presumably mirrors the lock
            # timeout used by the implementation — confirm it stays in sync.
            self.expected_expiry[transfer.secrethash] = self.block_number + 10
        return ActionInitInitiator(transfer, [factories.make_route_from_channel(channel)])

    def _receive_secret_request(self, transfer: TransferDescriptionWithSecretState):
        """Build the SecretRequest the target would send back for `transfer`."""
        secrethash = sha256(transfer.secret).digest()
        return ReceiveSecretRequest(
            payment_identifier=transfer.payment_identifier,
            amount=transfer.amount,
            expiration=self.expected_expiry[transfer.secrethash],
            secrethash=secrethash,
            sender=transfer.target,
        )

    def _new_transfer_description(self, target, payment_id, amount, secret):
        """Create a transfer description and remember that `secret` is now used."""
        self.used_secrets.add(secret)
        return TransferDescriptionWithSecretState(
            token_network_registry_address=self.token_network_registry_address,
            payment_identifier=payment_id,
            amount=amount,
            token_network_address=self.token_network_address,
            initiator=self.address,
            target=target,
            secret=secret,
        )

    def _invalid_authentic_secret_request(self, previous, action):
        """Feed an authentic-but-invalid SecretRequest and check it is ignored
        when already processed or when the lock has been removed; otherwise
        mark the hash as processed."""
        result = node.state_transition(self.chain_state, action)
        if action.secrethash in self.processed_secret_requests or self._is_removed(previous):
            assert not result.events
        else:
            self.processed_secret_requests.add(action.secrethash)

    def _unauthentic_secret_request(self, action):
        """An unauthentic SecretRequest must never produce events."""
        result = node.state_transition(self.chain_state, action)
        assert not result.events

    def _available_amount(self, partner_address):
        """Distributable capacity of our channel with `partner_address`."""
        netting_channel = self.address_to_channel[partner_address]
        return channel.get_distributable(netting_channel.our_state, netting_channel.partner_state)

    def _assume_channel_opened(self, action):
        assume(self.channel_opened(action.transfer.target))

    def _is_removed(self, action):
        """True once the lock for `action`'s transfer should have been removed."""
        expiry = self.expected_expiry[action.transfer.secrethash]
        return self.block_number >= expiry + DEFAULT_WAIT_BEFORE_LOCK_REMOVAL

    init_initiators = Bundle("init_initiators")

    @rule(
        target=init_initiators,
        partner=partners,
        payment_id=payment_id(),  # pylint: disable=no-value-for-parameter
        amount=integers(min_value=1, max_value=100),
        secret=secret(),  # pylint: disable=no-value-for-parameter
    )
    def valid_init_initiator(self, partner, payment_id, amount, secret):
        """A fresh, affordable transfer must produce a SendLockedTransfer."""
        assume(amount <= self._available_amount(partner))
        assume(secret not in self.used_secrets)
        transfer = self._new_transfer_description(partner, payment_id, amount, secret)
        action = self._action_init_initiator(transfer)
        result = node.state_transition(self.chain_state, action)
        assert event_types_match(result.events, SendLockedTransfer)
        self.initiated.add(transfer.secret)
        self.expected_expiry[transfer.secrethash] = self.block_number + 10
        return action

    @rule(
        partner=partners,
        payment_id=payment_id(),  # pylint: disable=no-value-for-parameter
        excess_amount=integers(min_value=1),
        secret=secret(),  # pylint: disable=no-value-for-parameter
    )
    def exceeded_capacity_init_initiator(self, partner, payment_id, excess_amount, secret):
        """A transfer above channel capacity must fail with EventPaymentSentFailed."""
        amount = self._available_amount(partner) + excess_amount
        transfer = self._new_transfer_description(partner, payment_id, amount, secret)
        action = self._action_init_initiator(transfer)
        result = node.state_transition(self.chain_state, action)
        assert event_types_match(result.events, EventPaymentSentFailed)
        self.event("ActionInitInitiator failed: Amount exceeded")

    @rule(
        previous_action=init_initiators,
        partner=partners,
        payment_id=payment_id(),  # pylint: disable=no-value-for-parameter
        amount=integers(min_value=1),
    )
    def used_secret_init_initiator(self, previous_action, partner, payment_id, amount):
        """Reusing the secret of a live transfer must be silently rejected."""
        assume(not self._is_removed(previous_action))
        secret = previous_action.transfer.secret
        transfer = self._new_transfer_description(partner, payment_id, amount, secret)
        action = self._action_init_initiator(transfer)
        result = node.state_transition(self.chain_state, action)
        assert not result.events
        self.event("ActionInitInitiator failed: Secret already in use.")

    # NOTE(review): "initator" is a typo for "initiator", but the method
    # name is the rule's public identifier — renaming would change
    # reported statistics/replays, so it is only flagged here.
    @rule(previous_action=init_initiators)
    def replay_init_initator(self, previous_action):
        """Replaying an already-processed init action must be a no-op."""
        assume(not self._is_removed(previous_action))
        result = node.state_transition(self.chain_state, previous_action)
        assert not result.events

    @rule(previous_action=init_initiators)
    def valid_secret_request(self, previous_action):
        """A valid SecretRequest reveals the secret unless dropped for a reason."""
        action = self._receive_secret_request(previous_action.transfer)
        self._assume_channel_opened(previous_action)
        result = node.state_transition(self.chain_state, action)
        if action.secrethash in self.processed_secret_requests:
            assert not result.events
            self.event("Valid SecretRequest dropped due to previous invalid one.")
        elif self._is_removed(previous_action):
            assert not result.events
            self.event("Otherwise valid SecretRequest dropped due to expired lock.")
        else:
            assert event_types_match(result.events, SendSecretReveal)
            self.event("Valid SecretRequest accepted.")
            self.processed_secret_requests.add(action.secrethash)

    @rule(previous_action=init_initiators, amount=integers())
    def wrong_amount_secret_request(self, previous_action, amount):
        """A SecretRequest with a mismatched amount is authentic but invalid."""
        assume(amount != previous_action.transfer.amount)
        self._assume_channel_opened(previous_action)
        transfer = deepcopy(previous_action.transfer)
        transfer.amount = amount
        action = self._receive_secret_request(transfer)
        self._invalid_authentic_secret_request(previous_action, action)

    @rule(
        previous_action=init_initiators,
        secret=secret()  # pylint: disable=no-value-for-parameter
    )
    def secret_request_with_wrong_secrethash(self, previous_action, secret):
        """A SecretRequest for an unknown secret hash must be ignored."""
        assume(
            sha256_secrethash(secret) != sha256_secrethash(previous_action.transfer.secret))
        self._assume_channel_opened(previous_action)
        transfer = deepcopy(previous_action.transfer)
        transfer.secret = secret
        action = self._receive_secret_request(transfer)
        # NOTE(review): the helper returns None, so this `return` is
        # cosmetic (rules without a target ignore return values).
        return self._unauthentic_secret_request(action)

    @rule(previous_action=init_initiators, payment_identifier=integers())
    def secret_request_with_wrong_payment_id(self, previous_action, payment_identifier):
        """A SecretRequest with a mismatched payment id must be ignored."""
        assume(
            payment_identifier != previous_action.transfer.payment_identifier)
        self._assume_channel_opened(previous_action)
        transfer = deepcopy(previous_action.transfer)
        transfer.payment_identifier = payment_identifier
        action = self._receive_secret_request(transfer)
        self._unauthentic_secret_request(action)
class PVectorEvolverBuilder(RuleBasedStateMachine):
    """
    Drive a plain ``list`` and a pvector evolver through the same sequence
    of mutations, keeping them side by side.

    Every rule applies one operation to both representations; the
    ``compare`` rule then checks that the pair still agrees and that the
    original inputs were never mutated.
    """

    sequences = Bundle("evolver_sequences")

    @rule(target=sequences, start=PVectorAndLists)
    def initial_value(self, start):
        """
        Seed the bundle with a (list, pvector) pair drawn from the strategy.
        """
        seed_list, seed_pvector = start
        return EvolverItem(
            original_list=seed_list,
            original_pvector=seed_pvector,
            current_list=seed_list[:],
            current_evolver=seed_pvector.evolver(),
        )

    @rule(item=sequences)
    def append(self, item):
        """
        Append one fresh object to both sequences.
        """
        new_obj = TestObject()
        item.current_list.append(new_obj)
        item.current_evolver.append(new_obj)

    @rule(start=sequences, end=sequences)
    def extend(self, start, end):
        """
        Extend one pair of sequences with another pair's current list.
        """
        # compare() has O(N**2) behavior, so don't want too-large lists:
        assume(len(start.current_list) + len(end.current_list) < 50)
        start.current_evolver.extend(end.current_list)
        start.current_list.extend(end.current_list)

    @rule(item=sequences, choice=st.choices())
    def delete(self, item, choice):
        """
        Delete one randomly chosen index from both sequences.
        """
        assume(item.current_list)
        idx = choice(range(len(item.current_list)))
        del item.current_list[idx]
        del item.current_evolver[idx]

    @rule(item=sequences, choice=st.choices())
    def setitem(self, item, choice):
        """
        Replace a random element via ``__setitem__`` on both sequences.
        """
        assume(item.current_list)
        idx = choice(range(len(item.current_list)))
        replacement = TestObject()
        item.current_list[idx] = replacement
        item.current_evolver[idx] = replacement

    @rule(item=sequences, choice=st.choices())
    def set(self, item, choice):
        """
        Replace a random element, exercising ``set`` on the evolver side.
        """
        assume(item.current_list)
        idx = choice(range(len(item.current_list)))
        replacement = TestObject()
        item.current_list[idx] = replacement
        item.current_evolver.set(idx, replacement)

    @rule(item=sequences)
    def compare(self, item):
        """
        The list, the evolver, and its persistent form must all agree.
        """
        item.current_evolver.is_dirty()
        # compare() has O(N**2) behavior, so don't want too-large lists:
        assume(len(item.current_list) < 50)
        # original object unmodified
        assert item.original_list == item.original_pvector
        # evolver matches:
        for idx in range(len(item.current_evolver)):
            assert item.current_list[idx] == item.current_evolver[idx]
        # persistent version matches
        assert_equal(item.current_list, item.current_evolver.persistent())
        # original object still unmodified
        assert item.original_list == item.original_pvector
class SyncMachine(RuleBasedStateMachine):
    """Stateful test of two-storage synchronization: builds in-memory
    storages (optionally with flaky/degenerate etag behavior), mutates
    them, runs sync() and asserts the post-sync invariants."""

    Status = Bundle('status')
    Storage = Bundle('storage')

    @rule(target=Storage,
          flaky_etags=st.booleans(),
          null_etag_on_upload=st.booleans())
    def newstorage(self, flaky_etags, null_etag_on_upload):
        s = MemoryStorage()
        if flaky_etags:
            # Every get() rewrites the stored etag with a fresh random one,
            # simulating servers whose etags are unstable.
            def get(href):
                old_etag, item = s.items[href]
                etag = _random_string()
                s.items[href] = etag, item
                return item, etag
            s.get = get

        if null_etag_on_upload:
            # Simulate servers that return a useless etag on write.
            # NOTE(review): the `and 'NULL'` means a falsy result from the
            # original update propagates unchanged — presumably deliberate;
            # confirm.
            _old_upload = s.upload
            _old_update = s.update
            s.upload = lambda item: (_old_upload(item)[0], 'NULL')
            s.update = lambda h, i, e: _old_update(h, i, e) and 'NULL'

        return s

    @rule(s=Storage, read_only=st.booleans())
    def is_read_only(self, s, read_only):
        # Only flip the flag; skip the step when it would be a no-op.
        assume(s.read_only != read_only)
        s.read_only = read_only

    @rule(s=Storage)
    def actions_fail(self, s):
        # All write operations on this storage now raise intentionally.
        s.upload = action_failure
        s.update = action_failure
        s.delete = action_failure

    @rule(s=Storage)
    def none_as_etag(self, s):
        # Wrap writes so they report no etag at all (upload returns None,
        # update returns nothing).
        _old_upload = s.upload
        _old_update = s.update

        def upload(item):
            return _old_upload(item)[0], None

        def update(href, item, etag):
            _old_update(href, item, etag)

        s.upload = upload
        s.update = update

    @rule(target=Status)
    def newstatus(self):
        # Sync status starts out as an empty mapping.
        return {}

    @rule(storage=Storage, uid=uid_strategy, etag=st.text())
    def upload(self, storage, uid, etag):
        # Inject an item directly, bypassing the storage's upload API.
        item = Item('UID:{}'.format(uid))
        storage.items[uid] = (etag, item)

    @rule(storage=Storage, href=st.text())
    def delete(self, storage, href):
        # Only count the step when something was actually removed.
        assume(storage.items.pop(href, None))

    @rule(status=Status, a=Storage, b=Storage,
          force_delete=st.booleans(),
          conflict_resolution=st.one_of((st.just('a wins'), st.just('b wins'))),
          with_error_callback=st.booleans(),
          partial_sync=st.one_of((st.just('ignore'), st.just('revert'),
                                  st.just('error'))))
    def sync(self, status, a, b, force_delete, conflict_resolution,
             with_error_callback, partial_sync):
        # Note: `sync(...)` inside this body resolves to the module-level
        # sync function, not this rule method.
        assume(a is not b)
        old_items_a = items(a)
        old_items_b = items(b)

        a.instance_name = 'a'
        b.instance_name = 'b'

        errors = []

        if with_error_callback:
            error_callback = errors.append
        else:
            error_callback = None

        try:
            # If one storage is read-only, double-sync because changes don't
            # get reverted immediately.
            for _ in range(2 if a.read_only or b.read_only else 1):
                sync(a, b, status,
                     force_delete=force_delete,
                     conflict_resolution=conflict_resolution,
                     error_callback=error_callback,
                     partial_sync=partial_sync)

            # Errors collected via the callback are re-raised so the
            # except clauses below classify them uniformly.
            for e in errors:
                raise e
        except PartialSync:
            assert partial_sync == 'error'
        except ActionIntentionallyFailed:
            pass
        except BothReadOnly:
            assert a.read_only and b.read_only
            assume(False)
        except StorageEmpty:
            if force_delete:
                raise
            else:
                assert not list(a.list()) or not list(b.list())
        else:
            items_a = items(a)
            items_b = items(b)

            # After a successful sync both sides agree (unless partial
            # changes were deliberately ignored) and read-only sides are
            # untouched; the status covers exactly the union of hrefs.
            assert items_a == items_b or partial_sync == 'ignore'
            assert items_a == old_items_a or not a.read_only
            assert items_b == old_items_b or not b.read_only

            assert set(a.items) | set(b.items) == set(status) or \
                partial_sync == 'ignore'
@composite
def secret(draw):
    """Strategy producing secrets by delegating to random_secret()."""
    return draw(builds(random_secret))


def event_types_match(events, *expected_types):
    """True when `events` contains exactly `expected_types` (order-insensitive)."""
    return Counter([type(event) for event in events]) == Counter(expected_types)


def transferred_amount(state):
    """Transferred amount from a balance proof, or 0 when there is none."""
    return 0 if not state.balance_proof else state.balance_proof.transferred_amount


partners = Bundle('partners')


# shared bundle of ChainStateStateMachine and all mixin classes
class ChainStateStateMachine(RuleBasedStateMachine):
    def __init__(self, address=None):
        self.address = address or factories.make_address()
        self.replay_path = False
        # Per-partner channel/privkey lookups.
        self.address_to_channel = dict()
        self.address_to_privkey = dict()
        # Per-partner monotonic counters used to check balance invariants.
        self.our_previous_deposit = defaultdict(int)
        self.partner_previous_deposit = defaultdict(int)
        self.our_previous_transferred = defaultdict(int)
        self.partner_previous_transferred = defaultdict(int)
        self.our_previous_unclaimed = defaultdict(int)
)


class IntAdder(RuleBasedStateMachine):
    pass


# Rules are attached dynamically via define_rule rather than the @rule
# decorator; both target the u'ints' bundle.
IntAdder.define_rule(
    targets=(u'ints',), function=lambda self, x: x, arguments={
        u'x': integers()
    }
)

IntAdder.define_rule(
    targets=(u'ints',), function=lambda self, x, y: x, arguments={
        u'x': integers(),
        u'y': Bundle(u'ints'),
    }
)


# Uses the legacy GenericStateMachine API: steps() yields a strategy and
# execute_step() consumes one drawn value (here a choices() callable).
class ChoosingMachine(GenericStateMachine):
    def steps(self):
        return choices()

    def execute_step(self, choices):
        choices([1, 2, 3])


with Settings(max_examples=10):
    TestChoosingMachine = ChoosingMachine.TestCase
class AuthStateMachine(RuleBasedStateMachine):
    """
    State machine for auth flows

    How to understand this code:

    This code exercises our social auth APIs, which is basically a graph of
    nodes and edges that the user traverses. You can understand the bundles
    defined below to be the nodes and the methods of this class to be the
    edges.

    If you add a new state to the auth flows, create a new bundle to
    represent that state and define methods to define transitions into and
    (optionally) out of that state.
    """

    # pylint: disable=too-many-instance-attributes

    # Bundles: each is a node of the auth-flow graph.
    ConfirmationSentAuthStates = Bundle("confirmation-sent")
    ConfirmationRedeemedAuthStates = Bundle("confirmation-redeemed")
    RegisterExtraDetailsAuthStates = Bundle("register-details-extra")
    LoginPasswordAuthStates = Bundle("login-password")
    LoginPasswordAbandonedAuthStates = Bundle("login-password-abandoned")

    # Patchers shared by all rules; started/stopped per run in __init__/teardown.
    recaptcha_patcher = patch(
        "authentication.views.requests.post",
        return_value=MockResponse(content='{"success": true}', status_code=status.HTTP_200_OK),
    )
    email_send_patcher = patch("mail.verification_api.send_verification_email", autospec=True)
    courseware_api_patcher = patch("authentication.pipeline.user.courseware_api")
    courseware_tasks_patcher = patch("authentication.pipeline.user.courseware_tasks")

    def __init__(self):
        """Setup the machine"""
        super().__init__()
        # wrap the execution in a django transaction, similar to django's TestCase
        self.atomic = transaction.atomic()
        self.atomic.__enter__()
        # wrap the execution in a patch()
        self.mock_email_send = self.email_send_patcher.start()
        self.mock_courseware_api = self.courseware_api_patcher.start()
        self.mock_courseware_tasks = self.courseware_tasks_patcher.start()
        # django test client
        self.client = Client()
        # shared data
        self.email = fake.email()
        self.user = None
        self.password = "******"
        # track whether we've hit an action that starts a flow or not
        self.flow_started = False

    def teardown(self):
        """Cleanup from a run"""
        # clear the mailbox
        del mail.outbox[:]
        # stop the patches
        self.email_send_patcher.stop()
        self.courseware_api_patcher.stop()
        self.courseware_tasks_patcher.stop()
        # end the transaction with a rollback to cleanup any state
        transaction.set_rollback(True)
        self.atomic.__exit__(None, None, None)

    def create_existing_user(self):
        """Create an existing user"""
        self.user = UserFactory.create(email=self.email)
        self.user.set_password(self.password)
        self.user.save()
        UserSocialAuthFactory.create(user=self.user, provider=EmailAuth.name, uid=self.user.email)

    @rule(
        target=ConfirmationSentAuthStates,
        recaptcha_enabled=st.sampled_from([True, False]),
    )
    @precondition(lambda self: not self.flow_started)
    def register_email_not_exists(self, recaptcha_enabled):
        """Register email not exists"""
        self.flow_started = True

        with ExitStack() as stack:
            mock_recaptcha_success = None
            if recaptcha_enabled:
                mock_recaptcha_success = stack.enter_context(self.recaptcha_patcher)
                stack.enter_context(override_settings(**{"RECAPTCHA_SITE_KEY": "fake"}))
            result = assert_api_call(
                self.client,
                "psa-register-email",
                {
                    "flow": SocialAuthState.FLOW_REGISTER,
                    "email": self.email,
                    **({"recaptcha": "fake"} if recaptcha_enabled else {}),
                },
                {
                    "flow": SocialAuthState.FLOW_REGISTER,
                    "partial_token": None,
                    "state": SocialAuthState.STATE_REGISTER_CONFIRM_SENT,
                },
            )
            self.mock_email_send.assert_called_once()
            if mock_recaptcha_success:
                mock_recaptcha_success.assert_called_once()
            return result

    @rule(target=LoginPasswordAuthStates, recaptcha_enabled=st.sampled_from([True, False]))
    @precondition(lambda self: not self.flow_started)
    def register_email_exists(self, recaptcha_enabled):
        """Register email exists"""
        self.flow_started = True
        self.create_existing_user()

        with ExitStack() as stack:
            mock_recaptcha_success = None
            if recaptcha_enabled:
                mock_recaptcha_success = stack.enter_context(self.recaptcha_patcher)
                stack.enter_context(override_settings(**{"RECAPTCHA_SITE_KEY": "fake"}))
            result = assert_api_call(
                self.client,
                "psa-register-email",
                {
                    "flow": SocialAuthState.FLOW_REGISTER,
                    "email": self.email,
                    "next": NEXT_URL,
                    **({"recaptcha": "fake"} if recaptcha_enabled else {}),
                },
                {
                    "flow": SocialAuthState.FLOW_REGISTER,
                    "state": SocialAuthState.STATE_LOGIN_PASSWORD,
                    "errors": ["Password is required to login"],
                },
            )
            # Existing account: no verification email should go out.
            self.mock_email_send.assert_not_called()
            if mock_recaptcha_success:
                mock_recaptcha_success.assert_called_once()
            return result

    @rule()
    @precondition(lambda self: not self.flow_started)
    def register_email_not_exists_with_recaptcha_invalid(self):
        """Yield a function for this step"""
        self.flow_started = True

        # NOTE(review): "fakse" below looks like a typo for "fake"; it is a
        # runtime settings value, so it is preserved unchanged here — confirm
        # whether it is intentional.
        with patch(
            "authentication.views.requests.post",
            return_value=MockResponse(
                content='{"success": false, "error-codes": ["bad-request"]}',
                status_code=status.HTTP_200_OK,
            ),
        ) as mock_recaptcha_failure, override_settings(**{"RECAPTCHA_SITE_KEY": "fakse"}):
            assert_api_call(
                self.client,
                "psa-register-email",
                {
                    "flow": SocialAuthState.FLOW_REGISTER,
                    "email": NEW_EMAIL,
                    "recaptcha": "fake",
                },
                {"error-codes": ["bad-request"], "success": False},
                expect_status=status.HTTP_400_BAD_REQUEST,
                use_defaults=False,
            )
            mock_recaptcha_failure.assert_called_once()
            self.mock_email_send.assert_not_called()

    @rule()
    @precondition(lambda self: not self.flow_started)
    def login_email_not_exists(self):
        """Login for an email that doesn't exist"""
        self.flow_started = True
        assert_api_call(
            self.client,
            "psa-login-email",
            {"flow": SocialAuthState.FLOW_LOGIN, "email": self.email},
            {
                "field_errors": {"email": "Couldn't find your account"},
                "flow": SocialAuthState.FLOW_LOGIN,
                "partial_token": None,
                "state": SocialAuthState.STATE_REGISTER_REQUIRED,
            },
        )
        assert User.objects.filter(email=self.email).exists() is False

    @rule(target=LoginPasswordAuthStates)
    @precondition(lambda self: not self.flow_started)
    def login_email_exists(self):
        """Login with a user that exists"""
        self.flow_started = True
        self.create_existing_user()
        return assert_api_call(
            self.client,
            "psa-login-email",
            {
                "flow": SocialAuthState.FLOW_LOGIN,
                "email": self.user.email,
                "next": NEXT_URL,
            },
            {
                "flow": SocialAuthState.FLOW_LOGIN,
                "state": SocialAuthState.STATE_LOGIN_PASSWORD,
                "extra_data": {"name": self.user.name},
            },
        )

    @rule(
        target=LoginPasswordAbandonedAuthStates,
        auth_state=consumes(RegisterExtraDetailsAuthStates),
    )
    @precondition(lambda self: self.flow_started)
    def login_email_abandoned(self, auth_state):  # pylint: disable=unused-argument
        """Login with a user that abandoned the register flow"""
        # NOTE: This works by "consuming" an extra details auth state,
        # but discarding the state and starting a new login.
        # It then re-targets the new state into the extra details again.
        auth_state = None  # assign None to ensure no accidental usage here

        return assert_api_call(
            self.client,
            "psa-login-email",
            {
                "flow": SocialAuthState.FLOW_LOGIN,
                "email": self.user.email,
                "next": NEXT_URL,
            },
            {
                "flow": SocialAuthState.FLOW_LOGIN,
                "state": SocialAuthState.STATE_LOGIN_PASSWORD,
                "extra_data": {"name": self.user.name},
            },
        )

    @rule(
        target=RegisterExtraDetailsAuthStates,
        auth_state=consumes(LoginPasswordAbandonedAuthStates),
    )
    def login_password_abandoned(self, auth_state):
        """Login with an abandoned registration user"""
        return assert_api_call(
            self.client,
            "psa-login-password",
            {
                "flow": auth_state["flow"],
                "partial_token": auth_state["partial_token"],
                "password": self.password,
            },
            {
                "flow": auth_state["flow"],
                "state": SocialAuthState.STATE_REGISTER_EXTRA_DETAILS,
            },
        )

    @rule(auth_state=consumes(LoginPasswordAuthStates))
    def login_password_valid(self, auth_state):
        """Login with a valid password"""
        assert_api_call(
            self.client,
            "psa-login-password",
            {
                "flow": auth_state["flow"],
                "partial_token": auth_state["partial_token"],
                "password": self.password,
            },
            {
                "flow": auth_state["flow"],
                "redirect_url": NEXT_URL,
                "partial_token": None,
                "state": SocialAuthState.STATE_SUCCESS,
            },
            expect_authenticated=True,
        )

    @rule(target=LoginPasswordAuthStates, auth_state=consumes(LoginPasswordAuthStates))
    def login_password_invalid(self, auth_state):
        """Login with an invalid password"""
        return assert_api_call(
            self.client,
            "psa-login-password",
            {
                "flow": auth_state["flow"],
                "partial_token": auth_state["partial_token"],
                "password": "******",
            },
            {
                "field_errors": {"password": "******"},
                "flow": auth_state["flow"],
                "state": SocialAuthState.STATE_ERROR,
            },
        )

    @rule(
        auth_state=consumes(LoginPasswordAuthStates),
        verify_exports=st.sampled_from([True, False]),
    )
    def login_password_user_inactive(self, auth_state, verify_exports):
        """Login for an inactive user"""
        self.user.is_active = False
        self.user.save()

        # Optionally run the exports (compliance) check alongside the login.
        cm = export_check_response("100_success") if verify_exports else noop()

        with cm:
            assert_api_call(
                self.client,
                "psa-login-password",
                {
                    "flow": auth_state["flow"],
                    "partial_token": auth_state["partial_token"],
                    "password": self.password,
                },
                {
                    "flow": auth_state["flow"],
                    "redirect_url": NEXT_URL,
                    "partial_token": None,
                    "state": SocialAuthState.STATE_SUCCESS,
                },
                expect_authenticated=True,
            )

    @rule(auth_state=consumes(LoginPasswordAuthStates))
    def login_password_exports_temporary_error(self, auth_state):
        """Login for a user who hasn't been OFAC verified yet"""
        with override_settings(**get_cybersource_test_settings()), patch(
            "authentication.pipeline.compliance.api.verify_user_with_exports",
            side_effect=Exception("register_details_export_temporary_error"),
        ):
            assert_api_call(
                self.client,
                "psa-login-password",
                {
                    "flow": auth_state["flow"],
                    "partial_token": auth_state["partial_token"],
                    "password": self.password,
                },
                {
                    "flow": auth_state["flow"],
                    "partial_token": None,
                    "state": SocialAuthState.STATE_ERROR_TEMPORARY,
                    "errors": ["Unable to register at this time, please try again later"],
                },
            )

    @rule(
        target=ConfirmationRedeemedAuthStates,
        auth_state=consumes(ConfirmationSentAuthStates),
    )
    def redeem_confirmation_code(self, auth_state):
        """Redeem a registration confirmation code"""
        # The code and partial token are recovered from the mocked email send.
        _, _, code, partial_token = self.mock_email_send.call_args[0]
        return assert_api_call(
            self.client,
            "psa-register-confirm",
            {
                "flow": auth_state["flow"],
                "verification_code": code.code,
                "partial_token": partial_token,
            },
            {
                "flow": auth_state["flow"],
                "state": SocialAuthState.STATE_REGISTER_DETAILS,
            },
        )

    @rule(auth_state=consumes(ConfirmationRedeemedAuthStates))
    def redeem_confirmation_code_twice(self, auth_state):
        """Redeeming a code twice should fail"""
        _, _, code, partial_token = self.mock_email_send.call_args[0]
        assert_api_call(
            self.client,
            "psa-register-confirm",
            {
                "flow": auth_state["flow"],
                "verification_code": code.code,
                "partial_token": partial_token,
            },
            {
                "errors": [],
                "flow": auth_state["flow"],
                "redirect_url": None,
                "partial_token": None,
                "state": SocialAuthState.STATE_INVALID_LINK,
            },
        )

    @rule(auth_state=consumes(ConfirmationRedeemedAuthStates))
    def redeem_confirmation_code_twice_existing_user(self, auth_state):
        """Redeeming a code twice with an existing user should fail with existing account state"""
        _, _, code, partial_token = self.mock_email_send.call_args[0]
        self.create_existing_user()
        assert_api_call(
            self.client,
            "psa-register-confirm",
            {
                "flow": auth_state["flow"],
                "verification_code": code.code,
                "partial_token": partial_token,
            },
            {
                "errors": [],
                "flow": auth_state["flow"],
                "redirect_url": None,
                "partial_token": None,
                "state": SocialAuthState.STATE_EXISTING_ACCOUNT,
            },
        )

    @rule(
        target=RegisterExtraDetailsAuthStates,
        auth_state=consumes(ConfirmationRedeemedAuthStates),
    )
    def register_details(self, auth_state):
        """Complete the register confirmation details page"""
        result = assert_api_call(
            self.client,
            "psa-register-details",
            {
                "flow": auth_state["flow"],
                "partial_token": auth_state["partial_token"],
                "password": self.password,
                "name": "Sally Smith",
                "legal_address": {
                    "first_name": "Sally",
                    "last_name": "Smith",
                    "street_address": ["Main Street"],
                    "country": "US",
                    "state_or_territory": "US-CO",
                    "city": "Boulder",
                    "postal_code": "02183",
                },
            },
            {
                "flow": auth_state["flow"],
                "state": SocialAuthState.STATE_REGISTER_EXTRA_DETAILS,
            },
        )
        self.user = User.objects.get(email=self.email)
        return result

    @rule(
        target=RegisterExtraDetailsAuthStates,
        auth_state=consumes(ConfirmationRedeemedAuthStates),
    )
    def register_details_export_success(self, auth_state):
        """Complete the register confirmation details page with exports enabled"""
        with export_check_response("100_success"):
            result = assert_api_call(
                self.client,
                "psa-register-details",
                {
                    "flow": auth_state["flow"],
                    "partial_token": auth_state["partial_token"],
                    "password": self.password,
                    "name": "Sally Smith",
                    "legal_address": {
                        "first_name": "Sally",
                        "last_name": "Smith",
                        "street_address": ["Main Street"],
                        "country": "US",
                        "state_or_territory": "US-CO",
                        "city": "Boulder",
                        "postal_code": "02183",
                    },
                },
                {
                    "flow": auth_state["flow"],
                    "state": SocialAuthState.STATE_REGISTER_EXTRA_DETAILS,
                },
            )

            assert ExportsInquiryLog.objects.filter(user__email=self.email).exists()
            assert (
                ExportsInquiryLog.objects.get(user__email=self.email).computed_result
                == RESULT_SUCCESS
            )
            # Success: no compliance notification email is sent.
            assert len(mail.outbox) == 0

        self.user = User.objects.get(email=self.email)
        return result

    @rule(auth_state=consumes(ConfirmationRedeemedAuthStates))
    def register_details_export_reject(self, auth_state):
        """Complete the register confirmation details page with exports enabled"""
        with export_check_response("700_reject"):
            assert_api_call(
                self.client,
                "psa-register-details",
                {
                    "flow": auth_state["flow"],
                    "partial_token": auth_state["partial_token"],
                    "password": self.password,
                    "name": "Sally Smith",
                    "legal_address": {
                        "first_name": "Sally",
                        "last_name": "Smith",
                        "street_address": ["Main Street"],
                        "country": "US",
                        "state_or_territory": "US-CO",
                        "city": "Boulder",
                        "postal_code": "02183",
                    },
                },
                {
                    "flow": auth_state["flow"],
                    "partial_token": None,
                    "errors": ["Error code: CS_700"],
                    "state": SocialAuthState.STATE_USER_BLOCKED,
                },
            )

            assert ExportsInquiryLog.objects.filter(user__email=self.email).exists()
            assert (
                ExportsInquiryLog.objects.get(user__email=self.email).computed_result
                == RESULT_DENIED
            )
            # Denied: one compliance notification email goes out.
            assert len(mail.outbox) == 1

    @rule(auth_state=consumes(ConfirmationRedeemedAuthStates))
    def register_details_export_temporary_error(self, auth_state):
        """Complete the register confirmation details page with exports raising a temporary error"""
        with override_settings(**get_cybersource_test_settings()), patch(
            "authentication.pipeline.compliance.api.verify_user_with_exports",
            side_effect=Exception("register_details_export_temporary_error"),
        ):
            assert_api_call(
                self.client,
                "psa-register-details",
                {
                    "flow": auth_state["flow"],
                    "partial_token": auth_state["partial_token"],
                    "password": self.password,
                    "name": "Sally Smith",
                    "legal_address": {
                        "first_name": "Sally",
                        "last_name": "Smith",
                        "street_address": ["Main Street"],
                        "country": "US",
                        "state_or_territory": "US-CO",
                        "city": "Boulder",
                        "postal_code": "02183",
                    },
                },
                {
                    "flow": auth_state["flow"],
                    "partial_token": None,
                    "errors": ["Unable to register at this time, please try again later"],
                    "state": SocialAuthState.STATE_ERROR_TEMPORARY,
                },
            )

            # Temporary failure: no inquiry is logged and no email sent.
            assert not ExportsInquiryLog.objects.filter(user__email=self.email).exists()
            assert len(mail.outbox) == 0

    @rule(auth_state=consumes(RegisterExtraDetailsAuthStates))
    def register_user_extra_details(self, auth_state):
        """Complete the user's extra details"""
        # A fresh Client() is used: this step must work without prior cookies.
        assert_api_call(
            Client(),
            "psa-register-extra",
            {
                "flow": auth_state["flow"],
                "partial_token": auth_state["partial_token"],
                "gender": "f",
                "birth_year": "2000",
                "company": "MIT",
                "job_title": "QA Manager",
            },
            {
                "flow": auth_state["flow"],
                "state": SocialAuthState.STATE_SUCCESS,
                "partial_token": None,
            },
            expect_authenticated=True,
        )
def test_rule_deprecation_bundle_by_name():
    """Passing a bundle's *name* (a str) as ``target`` must be rejected."""
    # A bundle called "k" exists, but rule() accepts only the Bundle object,
    # never its string name.
    Bundle("k")
    with pytest.raises(InvalidArgument):
        rule(target="k")
class TrustchainValidate(RuleBasedStateMachine):
    """Stateful test growing a certificate trustchain and validating loads.

    NOTE(review): this class is defined inside an enclosing test function —
    the ``nonlocal name_count`` statements and the references to ``caplog``,
    ``coolorg`` and ``local_device_factory`` close over that scope.
    """

    NonRevokedAdminUsers = Bundle("admin users")
    NonRevokedOtherUsers = Bundle("other users")
    RevokedUsers = Bundle("revoked users")

    def next_user_id(self):
        # Monotonic counter shared with the enclosing function scope.
        nonlocal name_count
        name_count += 1
        return UserID(f"user{name_count}")

    def next_device_id(self, user_id=None):
        nonlocal name_count
        user_id = user_id or self.next_user_id()
        name_count += 1
        return user_id.to_device_id(DeviceName(f"dev{name_count}"))

    def new_user_and_device(self, is_admin, certifier_id, certifier_key):
        """Create a user + first device, both certified by `certifier_id`."""
        device_id = self.next_device_id()
        local_device = local_device_factory(device_id, org=coolorg)
        self.local_devices[device_id] = local_device

        user = UserCertificateContent(
            author=certifier_id,
            timestamp=pendulum_now(),
            user_id=local_device.user_id,
            human_handle=local_device.human_handle,
            public_key=local_device.public_key,
            profile=UserProfile.ADMIN if is_admin else UserProfile.STANDARD,
        )
        self.users_content[device_id.user_id] = user
        self.users_certifs[device_id.user_id] = user.dump_and_sign(certifier_key)

        device = DeviceCertificateContent(
            author=certifier_id,
            timestamp=pendulum_now(),
            device_id=local_device.device_id,
            device_label=local_device.device_label,
            verify_key=local_device.verify_key,
        )
        self.devices_content[local_device.device_id] = device
        self.devices_certifs[local_device.device_id] = device.dump_and_sign(certifier_key)

        return device_id

    @initialize(target=NonRevokedAdminUsers)
    def init(self):
        """Reset all bookkeeping and bootstrap a root-signed admin user."""
        caplog.clear()
        self.users_certifs = {}
        self.users_content = {}
        self.revoked_users_certifs = {}
        self.revoked_users_content = {}
        self.devices_certifs = {}
        self.devices_content = {}
        self.local_devices = {}
        device_id = self.new_user_and_device(
            is_admin=True, certifier_id=None, certifier_key=coolorg.root_signing_key
        )
        note(f"new device: {device_id}")
        return device_id.user_id

    def get_device(self, user_id, device_rand):
        """Pick one of `user_id`'s devices deterministically from a random int."""
        user_devices = [
            device
            for device_id, device in self.local_devices.items()
            if device_id.user_id == user_id
        ]
        return user_devices[device_rand % len(user_devices)]

    @rule(
        target=NonRevokedAdminUsers,
        author_user=NonRevokedAdminUsers,
        author_device_rand=st.integers(min_value=0),
    )
    def new_admin_user(self, author_user, author_device_rand):
        author = self.get_device(author_user, author_device_rand)
        device_id = self.new_user_and_device(
            is_admin=True, certifier_id=author.device_id, certifier_key=author.signing_key
        )
        note(f"new device: {device_id} (author: {author.device_id})")
        return device_id.user_id

    @rule(
        target=NonRevokedOtherUsers,
        author_user=NonRevokedAdminUsers,
        author_device_rand=st.integers(min_value=0),
    )
    def new_non_admin_user(self, author_user, author_device_rand):
        author = self.get_device(author_user, author_device_rand)
        device_id = self.new_user_and_device(
            is_admin=False, certifier_id=author.device_id, certifier_key=author.signing_key
        )
        note(f"new device: {device_id} (author: {author.device_id})")
        return device_id.user_id

    # NOTE(review): the precondition tests ``d.is_admin`` while revoke_user
    # filters on ``device.profile == UserProfile.ADMIN`` — verify these two
    # predicates agree on the LocalDevice type.
    @precondition(lambda self: len([d for d in self.local_devices.values() if d.is_admin]) > 1)
    @rule(
        target=RevokedUsers,
        user=st.one_of(consumes(NonRevokedAdminUsers), consumes(NonRevokedOtherUsers)),
        author_rand=st.integers(min_value=0),
    )
    def revoke_user(self, user, author_rand):
        """Revoke `user`, signed by some *other* admin device."""
        possible_authors = [
            device
            for device_id, device in self.local_devices.items()
            if device_id.user_id != user and device.profile == UserProfile.ADMIN
        ]
        author = possible_authors[author_rand % len(possible_authors)]
        note(f"revoke user: {user} (author: {author.device_id})")
        revoked_user = RevokedUserCertificateContent(
            author=author.device_id, timestamp=pendulum_now(), user_id=user
        )
        self.revoked_users_content[user] = revoked_user
        self.revoked_users_certifs[user] = revoked_user.dump_and_sign(author.signing_key)
        return user

    @rule(
        user=st.one_of(NonRevokedAdminUsers, NonRevokedOtherUsers),
        author_user=NonRevokedAdminUsers,
        author_device_rand=st.integers(min_value=0),
    )
    def new_device(self, user, author_user, author_device_rand):
        """Add an extra device to an existing (non-revoked) user."""
        author = self.get_device(author_user, author_device_rand)
        device_id = self.next_device_id(user)
        note(f"new device: {device_id} (author: {author.device_id})")
        local_device = local_device_factory(device_id, org=coolorg)
        device = DeviceCertificateContent(
            author=author.device_id,
            timestamp=pendulum_now(),
            device_id=local_device.device_id,
            device_label=local_device.device_label,
            verify_key=local_device.verify_key,
        )
        self.devices_content[local_device.device_id] = device
        self.devices_certifs[local_device.device_id] = device.dump_and_sign(author.signing_key)

    @rule(user=st.one_of(NonRevokedAdminUsers, NonRevokedOtherUsers))
    def load_trustchain(self, user):
        """Load `user` through TrustchainContext and compare with our model."""
        ctx = TrustchainContext(coolorg.root_verify_key, 1)
        user_certif = next(
            certif for user_id, certif in self.users_certifs.items() if user_id == user
        )
        revoked_user_certif = next(
            (certif for user_id, certif in self.revoked_users_certifs.items() if user_id == user),
            None,
        )
        devices_certifs = [
            certif
            for device_id, certif in self.devices_certifs.items()
            if device_id.user_id == user
        ]
        user_content, revoked_user_content, devices_contents = ctx.load_user_and_devices(
            trustchain={
                "users": [certif for certif in self.users_certifs.values()],
                "revoked_users": [certif for certif in self.revoked_users_certifs.values()],
                "devices": [certif for certif in self.devices_certifs.values()],
            },
            user_certif=user_certif,
            revoked_user_certif=revoked_user_certif,
            devices_certifs=devices_certifs,
            expected_user_id=user,
        )
        expected_user_content = next(
            content for user_id, content in self.users_content.items() if user_id == user
        )
        expected_revoked_user_content = next(
            (content for user_id, content in self.revoked_users_content.items() if user_id == user),
            None,
        )
        expected_devices_contents = [
            content
            for device_id, content in self.devices_content.items()
            if device_id.user_id == user
        ]
        assert user_content == expected_user_content
        assert revoked_user_content == expected_revoked_user_content
        assert sorted(devices_contents, key=lambda device: device.device_id) == sorted(
            expected_devices_contents, key=lambda device: device.device_id
        )
def test_rule_deprecation_targets_and_target():
    """Supplying both ``targets`` and ``target`` to rule() must raise."""
    k = Bundle("k")
    v = Bundle("v")
    # The two keyword forms are mutually exclusive.
    with pytest.raises(InvalidArgument):
        rule(targets=(k,), target=v)
def test_deprecated_target_consumes_bundle():
    """Using ``consumes(...)`` as a rule *target* emits a deprecation warning."""
    # It would be nicer to raise this error at runtime, but the internals make
    # this sadly impractical.  Most InvalidDefinition errors happen at, well,
    # definition-time already anyway, so it's not *worse* than the status quo.
    bundle = Bundle("b")
    with validate_deprecation():
        rule(target=consumes(bundle))
class StatefulTestFileGenerator(RuleBasedStateMachine):
    """Stateful test of FileGenerator: build random django models and check
    the generated assertion file is consistent with them."""

    model_object = Bundle("model_object")
    name = Bundle("name")
    constants = Bundle("constants")
    fields = Bundle("fields")
    meta = Bundle("meta")

    @initialize()
    def remove_generated_file(self):
        """Start each run from a clean slate (no generated file)."""
        if os.path.isfile(FILE):
            os.remove(FILE)

    @rule(target=name, name=fake_class_name())
    def add_name(self, name):
        """Produce a fresh model class name (skipping already-used ones)."""
        assume(not model_exists(name))
        return name

    @rule(target=constants, constants=fake_constants())
    def add_constants(self, constants):
        return constants

    @rule(target=fields, fields=fake_fields_data())
    def add_fields(self, fields):
        return fields

    @rule(target=meta, meta=default_meta())
    def add_meta(self, meta):
        return meta

    @rule(
        target=model_object,
        name=consumes(name),
        constants=constants,
        fields=fields,
        meta=meta,
    )
    def add_model_object(self, name, constants, fields, meta):
        """Assemble a django model from drawn parts and wrap it as a model object."""
        # Remove Duplicates Fields.
        for field in fields:
            constants.pop(field, None)
        try:
            django_model = get_django_model(name=name, constants=constants, fields=fields, meta=meta)
            model_object = get_model_object(django_model)
        except Exception as e:
            # NOTE(review): pytest.fail expects a message; passing the raw
            # exception object relies on implicit str conversion — confirm.
            pytest.fail(e)
        else:
            return model_object

    @rule(original=consumes(model_object), tester=consumes(model_object))
    def assert_file_generator(self, original, tester):
        """Run FileGenerator over a model pair and validate the emitted file."""
        initial_file = []
        if os.path.isfile(FILE):
            event("assert_file_generator: File already exists.")
            with open(FILE, "r") as f:
                initial_file = f.read().splitlines()
        else:
            event("assert_file_generator: File doesn't exists.")

        file_generator_instance = FileGenerator(original, tester)

        # Test File exists.
        assert os.path.isfile(FILE)
        with open(FILE, "r") as f:
            modified_file = f.read().splitlines()

        # Test File has Header.
        assert all(
            modified_file[index] == line
            for index, line in enumerate(FILE_HEADER.splitlines())
        )

        # Test initial data isn't modified.
        for line_number, line in enumerate(initial_file):
            assert line == modified_file[line_number]

        # Test names are corrects and attributes exists.
        appended_data = modified_file[len(initial_file):]
        pattern = r"assert (?P<original>\S+) == (?P<tester>\S+), assert_msg(.+)"
        for line in appended_data:
            assert_line = re.search(pattern, line)
            if assert_line:
                # Split each matched breadcrumb into model name + attribute path.
                assert_line = {
                    model: {
                        "name": breadcrumb.split(".")[0],
                        "attr": ".".join(breadcrumb.split(".")[1:]),
                    }
                    for model, breadcrumb in assert_line.groupdict().items()
                }
                # Test names are corrects.
                assert assert_line["original"]["name"] == original._meta.name
                assert assert_line["tester"]["name"] == tester._meta.name
                # Test attributes compared are the same.
                assert assert_line["original"]["attr"] == assert_line["tester"]["attr"]
                # Test attributes exists.
                assert hasattrs(original, assert_line["original"]["attr"])
                assert hasattrs(tester, assert_line["tester"]["attr"])

        # Try retrieve generated functions.
        try:
            generated_functions = file_generator_instance.get_functions()
        except Exception as e:
            pytest.fail(e)

        # Test Module Import.
        try:
            module = sys.modules[MODULE]
            # Test Module has the Generated Class.
            assert hasattr(module, tester._meta.name)
            # Test Generated Class has the generated functions.
            generated_class = getattr(module, tester._meta.name)
            for generated_function in generated_functions.keys():
                assert hasattr(generated_class, generated_function)
        except KeyError as e:
            pytest.fail(e)
class HypothesisSpec(RuleBasedStateMachine):
    """Meta state machine: randomly composes Hypothesis strategies and runs
    small throwaway tests against them to exercise the engine itself."""

    def __init__(self):
        super(HypothesisSpec, self).__init__()
        self.database = None

    strategies = Bundle(u'strategy')
    strategy_tuples = Bundle(u'tuples')
    objects = Bundle(u'objects')
    basic_data = Bundle(u'basic')
    varied_floats = Bundle(u'varied_floats')

    def teardown(self):
        self.clear_database()

    @rule()
    def clear_database(self):
        if self.database is not None:
            self.database.close()
            self.database = None

    @rule()
    def set_database(self):
        self.teardown()
        self.database = ExampleDatabase()

    @rule(strat=strategies, r=integers(), max_shrinks=integers(0, 100))
    def find_constant_failure(self, strat, r, max_shrinks):
        """Run a test that always fails and let the engine shrink it."""
        with settings(
            verbosity=Verbosity.quiet, max_examples=1,
            min_satisfying_examples=0, database=self.database,
            max_shrinks=max_shrinks,
        ):
            @given(strat)
            @seed(r)
            def test(x):
                assert False

            try:
                test()
            except (AssertionError, FailedHealthCheck):
                pass

    @rule(strat=strategies, r=integers(), p=floats(0, 1),
          max_examples=integers(1, 10), max_shrinks=integers(1, 100))
    def find_weird_failure(self, strat, r, max_examples, p, max_shrinks):
        """Run a test that fails pseudo-randomly (probability derived from a hash)."""
        with settings(
            verbosity=Verbosity.quiet, max_examples=max_examples,
            min_satisfying_examples=0, database=self.database,
            max_shrinks=max_shrinks,
        ):
            @given(strat)
            @seed(r)
            def test(x):
                # Deterministic per-value "coin flip" seeded from repr(x).
                assert Random(
                    hashlib.md5(repr(x).encode(u'utf-8')).digest()
                ).random() <= p

            try:
                test()
            except (AssertionError, FailedHealthCheck):
                pass

    @rule(target=strategies, spec=sampled_from((
        integers(), booleans(), floats(), complex_numbers(),
        fractions(), decimals(), text(), binary(), none(), tuples(),
    )))
    def strategy(self, spec):
        return spec

    @rule(target=strategies, values=lists(integers() | text(), min_size=1))
    def sampled_from_strategy(self, values):
        return sampled_from(values)

    @rule(target=strategies, spec=strategy_tuples)
    def strategy_for_tupes(self, spec):
        return tuples(*spec)

    # NOTE(review): first parameter is named ``s`` rather than ``self`` —
    # this works (it is still the bound instance) but is unconventional.
    @rule(target=strategies, source=strategies,
          level=integers(1, 10), mixer=text())
    def filtered_strategy(s, source, level, mixer):
        def is_good(x):
            # Deterministic predicate keyed on (mixer, repr(x)).
            return bool(Random(
                hashlib.md5((mixer + repr(x)).encode(u'utf-8')).digest()
            ).randint(0, level))

        return source.filter(is_good)

    @rule(target=strategies, elements=strategies)
    def list_strategy(self, elements):
        return lists(elements)

    @rule(target=strategies, left=strategies, right=strategies)
    def or_strategy(self, left, right):
        return left | right

    @rule(target=varied_floats, source=floats())
    def float(self, source):
        # Method name shadows the builtin, but only inside this class namespace.
        return source

    @rule(target=varied_floats, source=varied_floats, offset=integers(-100, 100))
    def adjust_float(self, source, offset):
        """Nudge a float by `offset` ulp-like steps via its bit representation."""
        return int_to_float(clamp(0, float_to_int(source) + offset, 2**64 - 1))

    @rule(target=strategies, left=varied_floats, right=varied_floats)
    def float_range(self, left, right):
        for f in (math.isnan, math.isinf):
            for x in (left, right):
                assume(not f(x))
        left, right = sorted((left, right))
        assert left <= right
        return floats(left, right)

    @rule(target=strategies, source=strategies,
          result1=strategies, result2=strategies,
          mixer=text(), p=floats(0, 1))
    def flatmapped_strategy(self, source, result1, result2, mixer, p):
        assume(result1 is not result2)

        def do_map(value):
            rep = repr(value)
            random = Random(
                hashlib.md5((mixer + rep).encode(u'utf-8')).digest())
            if random.random() <= p:
                return result1
            else:
                return result2

        return source.flatmap(do_map)

    @rule(target=strategies, value=objects)
    def just_strategy(self, value):
        return just(value)

    @rule(target=strategy_tuples, source=strategies)
    def single_tuple(self, source):
        return (source, )

    @rule(target=strategy_tuples, left=strategy_tuples, right=strategy_tuples)
    def cat_tuples(self, left, right):
        return left + right

    @rule(target=objects, strat=strategies, data=data())
    def get_example(self, strat, data):
        data.draw(strat)

    @rule(target=strategies, left=integers(), right=integers())
    def integer_range(self, left, right):
        left, right = sorted((left, right))
        return integers(left, right)

    @rule(strat=strategies)
    def repr_is_good(self, strat):
        # A default object repr (``<... at 0x...>``) would indicate a missing __repr__.
        assert u' at 0x' not in repr(strat)
class GraphCompare(RuleBasedStateMachine):
    """Differential test: build identical computational graphs with the naive
    `Node` implementation and with `Tensor`, and check they always agree."""

    def __init__(self):
        super().__init__()
        # stores the corresponding node/tensor v1, v2, ... as they are
        # created via the unit test (through `create_node` or `fuse_nodes`)
        # `Node` is the naive implementation of `Tensor` that we are checking
        # against
        self.node_list = []  # type: List[Tuple[Node, Tensor]]
        self.str_to_tensor_op = {"add": add, "multiply": multiply}
        self.str_to_node_op = {"add": _add, "multiply": _multiply}
        # set once backward() raised on both sides; disables the invariant
        self.raised = False

    nodes = Bundle("nodes")

    @rule(target=nodes, value=st.floats(-10, 10), constant=st.booleans())
    def create_node(self, value, constant):
        """Create a matching (Node, Tensor) leaf pair."""
        n = Node(value, constant=constant)
        t = Tensor(value, constant=constant)
        self.node_list.append((n, t))
        return n, t

    @rule(
        target=nodes,
        a=nodes,
        b=nodes,
        op=st.sampled_from(["add", "multiply"]),
        constant=st.booleans(),
    )
    def fuse_nodes(self, a, b, op, constant):
        """ Combine any pair of nodes (tensors) using either
        addition or multiplication, producing a new node (tensor)"""
        n_a, t_a = a
        n_b, t_b = b
        n_op = self.str_to_node_op[op]
        t_op = self.str_to_tensor_op[op]
        out = (n_op(n_a, n_b, constant=constant), t_op(t_a, t_b, constant=constant))
        self.node_list.append(out)
        return out

    @rule(items=nodes, clear_graph=st.booleans())
    def null_gradients(self, items, clear_graph):
        """ Invoke `null_gradients` on the computational graph (naive and
        mygrad), with `clear_graph=True` specified optionally. """
        n, t = items
        n.null_gradients(clear_graph=clear_graph)
        t.null_gradients(clear_graph=clear_graph)

    @rule(items=nodes)
    def clear_graph(self, items):
        """ Invoke `clear_graph` on the computational graph (naive and mygrad) """
        n, t = items
        n.clear_graph()
        t.clear_graph()

    @rule(items=nodes, grad=st.floats(-10, 10))
    def backprop(self, items, grad):
        """ Invoke `backward(grad)` on the computational graph (naive and
        mygrad) from a randomly-selected node in the computational graph and
        using a randomly-generated gradient value.

        An exception should be raised if `clear_graph` is invoked anywhere
        prior to the invoking node. """
        n, t = items
        n.null_gradients(clear_graph=False)
        t.null_gradients(clear_graph=False)
        try:
            n.backward(grad, terminal_node=True)
        except Exception:
            # The naive graph raised, so the Tensor graph must raise too.
            with raises(Exception):
                t.backward(grad)
            self.raised = True
        else:
            t.backward(grad)

    @precondition(lambda self: not self.raised)
    @invariant()
    def all_agree(self):
        """ Ensure that all corresponding nodes/tensors have matching data
        and gradients across the respective graphs. """
        for num, (n, t) in enumerate(self.node_list):
            assert bool(n._ops) is bool(t._ops), _node_ID_str(num)
            assert_equal(n.data, t.data, err_msg=_node_ID_str(num))
            if n.grad is None or t.grad is None:
                assert n.grad is t.grad, _node_ID_str(num)
            else:
                assert_allclose(
                    actual=t.grad,
                    desired=n.grad,
                    atol=1e-5,
                    rtol=1e-5,
                    err_msg=_node_ID_str(num),
                )
            assert not t._accum_ops, _node_ID_str(num)
class SetModel(RuleBasedStateMachine):
    """Model-based test of `IntSet`: every bundle entry is a pair
    `(IntSet, sorted_list_model)` and each rule applies the same operation to
    both representations; the `validate*` rules assert they agree."""

    intsets = Bundle('IntSets')  # entries: (IntSet, sorted list of its members)
    values = Bundle('values')    # entries: plain integers

    @rule(target=values, i=integers_in_range)
    def int_value(self, i):
        # An arbitrary in-range integer.
        return i

    @rule(target=values, i=integers_in_range, imp=intsets)
    def endpoint_value(self, i, imp):
        # Prefer the largest member of an existing set (interesting boundary);
        # fall back to an arbitrary integer for empty sets.
        if len(imp[0]) > 0:
            return imp[0][-1]
        else:
            return i

    @rule(target=values, i=integers_in_range, imp=intsets)
    def startpoint_value(self, i, imp):
        # Same as above but using the smallest member.
        if len(imp[0]) > 0:
            return imp[0][0]
        else:
            return i

    @rule(target=intsets, bounds=short_intervals)
    def build_interval(self, bounds):
        # Half-open interval [lo, hi) mirrored by `range`.
        return (IntSet.interval(*bounds), list(range(*bounds)))

    @rule(target=intsets, v=values)
    def single_value(self, v):
        return (IntSet.single(v), [v])

    @rule(target=intsets, v=values)
    def adjacent_values(self, v):
        # Two consecutive members; bound checked against the 2**64 domain.
        assume(v + 1 <= 2 ** 64)
        return (IntSet.interval(v, v + 2), [v, v + 1])

    @rule(target=intsets, v=values)
    def three_adjacent_values(self, v):
        assume(v + 2 <= 2 ** 64)
        return (IntSet.interval(v, v + 3), [v, v + 1, v + 2])

    @rule(target=intsets, v=values)
    def three_adjacent_values_with_hole(self, v):
        # v and v+2 present, v+1 absent — exercises non-contiguous storage.
        assume(v + 2 <= 2 ** 64)
        return (IntSet.single(v) | IntSet.single(v + 2), [v, v + 2])

    @rule(target=intsets, x=intsets, y=intsets)
    def union(self, x, y):
        return (x[0] | y[0], sorted(set(x[1] + y[1])))

    @rule(target=intsets, x=intsets, y=intsets)
    def intersect(self, x, y):
        return (x[0] & y[0], sorted(set(x[1]) & set(y[1])))

    @rule(target=intsets, x=intsets, y=intsets)
    def subtract(self, x, y):
        return (x[0] - y[0], sorted(set(x[1]) - set(y[1])))

    @rule(target=intsets, x=intsets, i=values)
    def insert(self, x, i):
        return (x[0].insert(i), sorted(set(x[1] + [i])))

    @rule(target=intsets, x=intsets, i=values)
    def discard(self, x, i):
        return (x[0].discard(i), sorted(set(x[1]) - set([i])))

    @rule(target=intsets, source=intsets, bounds=intervals)
    def restrict(self, source, bounds):
        # Keep only members in the half-open range [bounds[0], bounds[1]).
        return (
            source[0].restrict(*bounds),
            [x for x in source[1] if bounds[0] <= x < bounds[1]])

    @rule(target=intsets, x=intsets)
    def peel_left(self, x):
        # Restrict to [min, max + 1): an identity restriction on non-empty sets.
        if len(x[0]) == 0:
            return x
        return self.restrict(x, (x[0][0], x[0][-1] + 1))

    @rule(target=intsets, x=intsets)
    def peel_right(self, x):
        # Restrict to [min, max): drops the largest member.
        if len(x[0]) == 0:
            return x
        return self.restrict(x, (x[0][0], x[0][-1]))

    @rule(x=intsets, y=intsets)
    def validate_order(self, x, y):
        # IntSet ordering must match lexicographic list ordering.
        assert (x[0] <= y[0]) == (x[1] <= y[1])

    @rule(x=intsets, y=intsets)
    def validate_equality(self, x, y):
        assert (x[0] == y[0]) == (x[1] == y[1])

    @rule(source=intsets)
    def validate(self, source):
        # Full consistency check: iteration order, length, positive and
        # negative indexing, and min/max bounding of every member.
        assert list(source[0]) == source[1]
        assert len(source[0]) == len(source[1])
        for i in range(-len(source[0]), len(source[0])):
            assert source[0][i] == source[1][i]
        if len(source[0]) > 0:
            for v in source[1]:
                assert source[0][0] <= v <= source[0][-1]
class QCircuitMachine(RuleBasedStateMachine):
    """Build a Hypothesis rule based state machine for constructing, transpiling
    and simulating a series of random QuantumCircuits.

    Build circuits with up to QISKIT_RANDOM_QUBITS qubits, apply a random
    selection of gates from qiskit.extensions.standard with randomly selected
    qargs, cargs, and parameters. At random intervals, transpile the circuit for
    a random backend with a random optimization level and simulate both the
    initial and the transpiled circuits to verify that their counts are the
    same.
    """

    qubits = Bundle('qubits')
    clbits = Bundle('clbits')

    backend = Aer.get_backend('qasm_simulator')
    # Use half the simulator's qubits so transpilation has room to remap.
    max_qubits = int(backend.configuration().n_qubits / 2)

    def __init__(self):
        super().__init__()
        self.qc = QuantumCircuit()

    @precondition(lambda self: len(self.qc.qubits) < self.max_qubits)
    @rule(target=qubits, n=st.integers(min_value=1, max_value=max_qubits))
    def add_qreg(self, n):
        """Adds a new variable sized qreg to the circuit, up to max_qubits."""
        # Clamp so the running total never exceeds max_qubits.
        n = min(n, self.max_qubits - len(self.qc.qubits))
        qreg = QuantumRegister(n)
        self.qc.add_register(qreg)
        return multiple(*list(qreg))

    @rule(target=clbits, n=st.integers(1, 5))
    def add_creg(self, n):
        """Add a new variable sized creg to the circuit."""
        creg = ClassicalRegister(n)
        self.qc.add_register(creg)
        return multiple(*list(creg))

    # Gates of various shapes

    @rule(gate=st.sampled_from(oneQ_gates), qarg=qubits)
    def add_1q_gate(self, gate, qarg):
        """Append a random 1q gate on a random qubit."""
        self.qc.append(gate(), [qarg], [])

    @rule(gate=st.sampled_from(twoQ_gates),
          qargs=st.lists(qubits, max_size=2, min_size=2, unique=True))
    def add_2q_gate(self, gate, qargs):
        """Append a random 2q gate across two random qubits."""
        self.qc.append(gate(), qargs)

    @rule(gate=st.sampled_from(threeQ_gates),
          qargs=st.lists(qubits, max_size=3, min_size=3, unique=True))
    def add_3q_gate(self, gate, qargs):
        """Append a random 3q gate across three random qubits."""
        self.qc.append(gate(), qargs)

    @rule(gate=st.sampled_from(oneQ_oneP_gates),
          qarg=qubits,
          param=st.floats(allow_nan=False, allow_infinity=False))
    def add_1q1p_gate(self, gate, qarg, param):
        """Append a random 1q gate with 1 random float parameter."""
        self.qc.append(gate(param), [qarg])

    @rule(gate=st.sampled_from(oneQ_twoP_gates),
          qarg=qubits,
          params=st.lists(
              st.floats(allow_nan=False, allow_infinity=False),
              min_size=2, max_size=2))
    def add_1q2p_gate(self, gate, qarg, params):
        """Append a random 1q gate with 2 random float parameters."""
        self.qc.append(gate(*params), [qarg])

    @rule(gate=st.sampled_from(oneQ_threeP_gates),
          qarg=qubits,
          params=st.lists(
              st.floats(allow_nan=False, allow_infinity=False),
              min_size=3, max_size=3))
    def add_1q3p_gate(self, gate, qarg, params):
        """Append a random 1q gate with 3 random float parameters."""
        self.qc.append(gate(*params), [qarg])

    @rule(gate=st.sampled_from(twoQ_oneP_gates),
          qargs=st.lists(qubits, max_size=2, min_size=2, unique=True),
          param=st.floats(allow_nan=False, allow_infinity=False))
    def add_2q1p_gate(self, gate, qargs, param):
        """Append a random 2q gate with 1 random float parameter."""
        self.qc.append(gate(param), qargs)

    @rule(gate=st.sampled_from(twoQ_threeP_gates),
          qargs=st.lists(qubits, max_size=2, min_size=2, unique=True),
          params=st.lists(
              st.floats(allow_nan=False, allow_infinity=False),
              min_size=3, max_size=3))
    def add_2q3p_gate(self, gate, qargs, params):
        """Append a random 2q gate with 3 random float parameters."""
        self.qc.append(gate(*params), qargs)

    @rule(gate=st.sampled_from(oneQ_oneC_gates), qarg=qubits, carg=clbits)
    def add_1q1c_gate(self, gate, qarg, carg):
        """Append a random 1q, 1c gate."""
        self.qc.append(gate(), [qarg], [carg])

    @rule(gate=st.sampled_from(variadic_gates),
          qargs=st.lists(qubits, min_size=1, unique=True))
    def add_variQ_gate(self, gate, qargs):
        """Append a gate with a variable number of qargs."""
        self.qc.append(gate(len(qargs)), qargs)

    @precondition(lambda self: len(self.qc.data) > 0)
    @rule(carg=clbits, data=st.data())
    def add_c_if_last_gate(self, carg, data):
        """Modify the last gate to be conditional on a classical register."""
        creg = carg.register
        val = data.draw(st.integers(min_value=0, max_value=2**len(creg)-1))

        last_gate = self.qc.data[-1]

        # Work around for https://github.com/Qiskit/qiskit-terra/issues/2567
        assume(not isinstance(last_gate[0], Measure)
               or creg != last_gate[2][0].register)

        last_gate[0].c_if(creg, val)

    # Properties to check

    @invariant()
    def qasm(self):
        """After each circuit operation, it should be possible to build QASM."""
        self.qc.qasm()

    @precondition(lambda self: any(isinstance(d[0], Measure) for d in self.qc.data))
    @rule(
        backend=st.one_of(
            st.none(),
            st.sampled_from(mock_backends)),
        opt_level=st.one_of(
            st.none(),
            st.integers(min_value=0, max_value=3)))
    def equivalent_transpile(self, backend, opt_level):
        """Simulate, transpile and simulate the present circuit. Verify that the
        counts are not significantly different before and after transpilation.
        """
        assume(backend is None
               or backend.configuration().n_qubits >= len(self.qc.qubits))

        shots = 4096

        # Reference counts from the un-transpiled circuit.
        # (Removed an unused local alias of self.backend that was dead code.)
        aer_counts = execute(self.qc, backend=self.backend,
                             shots=shots).result().get_counts()

        try:
            xpiled_qc = transpile(self.qc, backend=backend,
                                  optimization_level=opt_level)
        except Exception as e:
            # Attach the failing circuit's QASM so the failure is reproducible.
            failed_qasm = 'Exception caught during transpilation of circuit: \n{}'.format(
                self.qc.qasm())
            raise RuntimeError(failed_qasm) from e

        xpiled_aer_counts = execute(xpiled_qc, backend=self.backend,
                                    shots=shots).result().get_counts()

        # Allow statistical noise of up to 5% of the shot count.
        count_differences = dicts_almost_equal(aer_counts, xpiled_aer_counts,
                                               0.05 * shots)

        assert count_differences == '', 'Counts not equivalent: {}\nFailing QASM: \n{}'.format(
            count_differences, self.qc.qasm())
class HypothesisSpec(RuleBasedStateMachine):
    """Meta-test machine that builds progressively more complex Hypothesis
    strategies out of simpler ones, draws examples from them, and checks
    their reprs stay sensible."""

    def __init__(self):
        super().__init__()
        self.database = None

    strategies = Bundle("strategy")
    strategy_tuples = Bundle("tuples")
    objects = Bundle("objects")
    basic_data = Bundle("basic")
    varied_floats = Bundle("varied_floats")

    def teardown(self):
        self.clear_database()

    @rule()
    def clear_database(self):
        if self.database is not None:
            self.database = None

    @rule()
    def set_database(self):
        self.teardown()
        self.database = ExampleDatabase()

    @rule(
        target=strategies,
        spec=sampled_from((
            integers(), booleans(), floats(), complex_numbers(), fractions(),
            decimals(), text(), binary(), none(), tuples(),
        )),
    )
    def strategy(self, spec):
        """Seed the bundle with one of the primitive strategies."""
        return spec

    @rule(target=strategies, values=lists(integers() | text(), min_size=1))
    def sampled_from_strategy(self, values):
        return sampled_from(values)

    # NOTE(review): "tupes" is a typo for "tuples", but the method name is the
    # rule name shown in reproduced failures, so it is kept as-is.
    @rule(target=strategies, spec=strategy_tuples)
    def strategy_for_tupes(self, spec):
        return tuples(*spec)

    @rule(target=strategies, source=strategies, level=integers(1, 10), mixer=text())
    def filtered_strategy(self, source, level, mixer):
        """Wrap `source` in a deterministic pseudo-random filter.

        Fixed: the instance parameter was named `s` instead of `self`; it only
        worked because @rule passes the instance positionally.
        """
        def is_good(x):
            # Deterministic per (mixer, repr(x)); `level` controls how often
            # the filter rejects (higher level -> fewer rejections).
            seed = hashlib.sha384((mixer + repr(x)).encode()).digest()
            return bool(Random(seed).randint(0, level))

        return source.filter(is_good)

    @rule(target=strategies, elements=strategies)
    def list_strategy(self, elements):
        return lists(elements)

    @rule(target=strategies, left=strategies, right=strategies)
    def or_strategy(self, left, right):
        return left | right

    # NOTE(review): shadows the builtin `float`, but only as a method name;
    # kept because the rule name is part of the machine's reported steps.
    @rule(target=varied_floats, source=floats())
    def float(self, source):
        return source

    @rule(target=varied_floats, source=varied_floats, offset=integers(-100, 100))
    def adjust_float(self, source, offset):
        # Nudge the float by a few ULPs via its bit representation, clamped
        # to the valid 64-bit range.
        return int_to_float(clamp(0, float_to_int(source) + offset, 2**64 - 1))

    @rule(target=strategies, left=varied_floats, right=varied_floats)
    def float_range(self, left, right):
        assume(math.isfinite(left) and math.isfinite(right))
        left, right = sorted((left, right))
        assert left <= right
        # exclude deprecated case where left = 0.0 and right = -0.0
        assume(left or right or not (is_negative(right) and not is_negative(left)))
        return floats(left, right)

    @rule(
        target=strategies,
        source=strategies,
        result1=strategies,
        result2=strategies,
        mixer=text(),
        p=floats(0, 1),
    )
    def flatmapped_strategy(self, source, result1, result2, mixer, p):
        """Flatmap `source` onto one of two strategies, chosen by a
        deterministic hash of the drawn value."""
        assume(result1 is not result2)

        def do_map(value):
            rep = repr(value)
            random = Random(hashlib.sha384((mixer + rep).encode()).digest())
            if random.random() <= p:
                return result1
            else:
                return result2

        return source.flatmap(do_map)

    @rule(target=strategies, value=objects)
    def just_strategy(self, value):
        return just(value)

    @rule(target=strategy_tuples, source=strategies)
    def single_tuple(self, source):
        return (source, )

    @rule(target=strategy_tuples, left=strategy_tuples, right=strategy_tuples)
    def cat_tuples(self, left, right):
        return left + right

    # NOTE(review): the drawn example is discarded, so the `objects` bundle
    # only ever receives None — confirm against upstream before "fixing".
    @rule(target=objects, strat=strategies, data=data())
    def get_example(self, strat, data):
        data.draw(strat)

    @rule(target=strategies, left=integers(), right=integers())
    def integer_range(self, left, right):
        left, right = sorted((left, right))
        return integers(left, right)

    @rule(strat=strategies)
    def repr_is_good(self, strat):
        # A default object repr ("<... at 0x...>") leaking through means a
        # strategy forgot to define a proper __repr__.
        assert " at 0x" not in repr(strat)
class FolderOperationsStateMachine(RuleBasedStateMachine):
    """Mirror every filesystem operation on a real on-disk "oracle" tree and on
    the mounted parsec workspace, asserting both raise (or don't) identically.

    NOTE(review): this class closes over `tentative`, `caplog`, `tmpdir`,
    `mountpoint_service_factory` etc. — it is defined inside an enclosing
    test fixture/function not visible in this chunk.
    """

    Files = Bundle("file")
    Folders = Bundle("folder")
    # Moving mountpoint
    NonRootFolder = Folders.filter(lambda x: not x.is_workspace())

    @initialize(target=Folders)
    def init(self):
        # Each machine run gets a fresh numbered oracle directory.
        nonlocal tentative
        tentative += 1
        caplog.clear()

        async def _bootstrap(user_fs, mountpoint_manager):
            wid = await user_fs.workspace_create("w")
            await mountpoint_manager.mount_workspace(wid)

        self.mountpoint_service = mountpoint_service_factory(_bootstrap)

        self.folder_oracle = Path(tmpdir / f"oracle-test-{tentative}")
        self.folder_oracle.mkdir()
        oracle_root = self.folder_oracle / "root"
        oracle_root.mkdir()
        self.folder_oracle.chmod(0o500)  # Root oracle can no longer be removed this way
        (oracle_root / "w").mkdir()
        oracle_root.chmod(0o500)  # Also protect workspace from deletion

        return PathElement(f"/w", self.mountpoint_service.base_mountpoint, oracle_root)

    def teardown(self):
        # `init` may not have run if the machine failed before @initialize.
        if hasattr(self, "mountpoint_service"):
            self.mountpoint_service.stop()

    @rule(target=Files, parent=Folders, name=st_entry_name)
    def touch(self, parent, name):
        # Run on oracle first, capture any OSError, then require the parsec
        # side to raise the same class of error (or none).
        path = parent / name
        expected_exc = None
        try:
            path.to_oracle().touch(exist_ok=False)
        except OSError as exc:
            expected_exc = exc
        with expect_raises(expected_exc):
            path.to_parsec().touch(exist_ok=False)
        return path

    @rule(target=Folders, parent=Folders, name=st_entry_name)
    def mkdir(self, parent, name):
        path = parent / name
        expected_exc = None
        try:
            path.to_oracle().mkdir(exist_ok=False)
        except OSError as exc:
            expected_exc = exc
        with expect_raises(expected_exc):
            path.to_parsec().mkdir(exist_ok=False)
        return path

    @rule(path=Files)
    def unlink(self, path):
        expected_exc = None
        try:
            path.to_oracle().unlink()
        except OSError as exc:
            expected_exc = exc
        with expect_raises(expected_exc):
            path.to_parsec().unlink()

    @rule(path=Files, length=st.integers(min_value=0, max_value=16))
    def resize(self, path, length):
        expected_exc = None
        try:
            os.truncate(path.to_oracle(), length)
        except OSError as exc:
            expected_exc = exc
        with expect_raises(expected_exc):
            os.truncate(path.to_parsec(), length)

    @rule(path=NonRootFolder)
    def rmdir(self, path):
        expected_exc = None
        try:
            path.to_oracle().rmdir()
        except OSError as exc:
            expected_exc = exc
        with expect_raises(expected_exc):
            path.to_parsec().rmdir()

    def _move(self, src, dst_parent, dst_name):
        # Shared implementation for move_file / move_folder.
        dst = dst_parent / dst_name
        expected_exc = None
        try:
            oracle_rename(src.to_oracle(), dst.to_oracle())
        except OSError as exc:
            expected_exc = exc
        with expect_raises(expected_exc):
            src.to_parsec().rename(str(dst.to_parsec()))
        return dst

    @rule(target=Files, src=Files, dst_parent=Folders, dst_name=st_entry_name)
    def move_file(self, src, dst_parent, dst_name):
        return self._move(src, dst_parent, dst_name)

    @rule(target=Folders, src=NonRootFolder, dst_parent=Folders, dst_name=st_entry_name)
    def move_folder(self, src, dst_parent, dst_name):
        return self._move(src, dst_parent, dst_name)

    @rule(path=Folders)
    def iterdir(self, path):
        expected_exc = None
        try:
            expected_children = {x.name for x in path.to_oracle().iterdir()}
        except OSError as exc:
            expected_exc = exc
        with expect_raises(expected_exc):
            children = {x.name for x in path.to_parsec().iterdir()}
        # Only compare listings when the oracle succeeded.
        if not expected_exc:
            assert children == expected_children
# NOTE(review): the `return` below is the tail of a helper (presumably an
# `event_types_match`-style function) whose `def` line lies above this chunk.
    return Counter([type(event) for event in events]) == Counter(expected_types)


def transferred_amount(state):
    # Channel end with no balance proof has transferred nothing yet.
    return 0 if not state.balance_proof else state.balance_proof.transferred_amount


# use of hypothesis.stateful.multiple() breaks the failed-example code
# generation at the moment, this function is a temporary workaround
def unwrap_multiple(multiple_results):
    # Unwrap a hypothesis `multiple()` result: a single value comes back bare,
    # several values come back as the list.
    values = multiple_results.values
    return values[0] if len(values) == 1 else values


# shared bundle of ChainStateStateMachine and all mixin classes
partners = Bundle("partners")


class ChainStateStateMachine(RuleBasedStateMachine):
    # NOTE(review): the visible portion of __init__ does not call
    # super().__init__(); it very likely continues past this chunk — confirm.
    def __init__(self, address=None):
        self.address = address or factories.make_address()
        self.replay_path = False
        self.address_to_channel = dict()
        self.address_to_privkey = dict()
        self.initial_number_of_channels = 1
        # per-partner snapshots used to check monotonicity of deposits/transfers
        self.our_previous_deposit = defaultdict(int)
        self.partner_previous_deposit = defaultdict(int)
        self.our_previous_transferred = defaultdict(int)
        self.partner_previous_transferred = defaultdict(int)
class IntAdder(RuleBasedStateMachine):
    # Rules are attached dynamically below via define_rule.
    pass


# Identity rule: feeds drawn integers into the "ints" bundle.
IntAdder.define_rule(targets=(u"ints", ),
                     function=lambda self, x: x,
                     arguments={u"x": integers()})
# Rule mixing a fresh integer with a previously-produced bundle value.
IntAdder.define_rule(
    targets=(u"ints", ),
    function=lambda self, x, y: x,
    arguments={
        u"x": integers(),
        u"y": Bundle(u"ints")
    },
)

TestDynamicMachine = DynamicMachine.TestCase
TestIntAdder = IntAdder.TestCase
TestPrecondition = PreconditionMachine.TestCase

# Keep these machines cheap to run.
for test_case in (TestDynamicMachine, TestIntAdder, TestPrecondition):
    test_case.settings = Settings(test_case.settings, max_examples=10)


def test_picks_up_settings_at_first_use_of_testcase():
    """Settings assigned to a TestCase class must be honored on first use."""
    assert TestDynamicMachine.settings.max_examples == 10
class MediatorMixin:
    """Mixin of mediator-role rules for ChainStateStateMachine.

    NOTE(review): relies on attributes provided by the host machine
    (`address_to_channel`, `address_to_privkey`, `block_number`,
    `chain_state`, `token_id`, `channel_opened`, `event`) — confirm against
    the composing class.
    """

    def __init__(self):
        super().__init__()
        self.partner_to_balance_proof_data = dict()
        self.secrethash_to_secret = dict()
        self.waiting_for_unlock = dict()
        self.initial_number_of_channels = 2

    def _get_balance_proof_data(self, partner):
        # Lazily create the expected balance-proof tracker for this partner.
        if partner not in self.partner_to_balance_proof_data:
            partner_channel = self.address_to_channel[partner]
            self.partner_to_balance_proof_data[partner] = BalanceProofData(
                canonical_identifier=partner_channel.canonical_identifier)
        return self.partner_to_balance_proof_data[partner]

    def _update_balance_proof_data(self, partner, amount, expiration, secret):
        # Advance the expected balance proof by one new pending lock.
        expected = self._get_balance_proof_data(partner)
        lock = HashTimeLockState(amount=amount,
                                 expiration=expiration,
                                 secrethash=sha256_secrethash(secret))
        expected.update(amount, lock)
        return expected

    init_mediators = Bundle("init_mediators")
    secret_requests = Bundle("secret_requests")
    unlocks = Bundle("unlocks")

    def _new_mediator_transfer(self, initiator_address, target_address,
                               payment_id, amount,
                               secret) -> LockedTransferSignedState:
        initiator_pkey = self.address_to_privkey[initiator_address]
        balance_proof_data = self._update_balance_proof_data(
            initiator_address, amount, self.block_number + 10, secret)
        # Remember the secret so later reveal rules can look it up by hash.
        self.secrethash_to_secret[sha256_secrethash(secret)] = secret

        return factories.create(
            factories.LockedTransferSignedStateProperties(
                **balance_proof_data.properties.__dict__,
                amount=amount,
                expiration=self.block_number + 10,
                payment_identifier=payment_id,
                secret=secret,
                initiator=initiator_address,
                target=target_address,
                token=self.token_id,
                sender=initiator_address,
                recipient=self.address,
                pkey=initiator_pkey,
                message_identifier=1,
            ))

    def _action_init_mediator(
            self, transfer: LockedTransferSignedState) -> ActionInitMediator:
        initiator_channel = self.address_to_channel[transfer.initiator]
        target_channel = self.address_to_channel[transfer.target]

        return ActionInitMediator(
            route_states=[factories.make_route_from_channel(target_channel)],
            from_hop=factories.make_hop_to_channel(initiator_channel),
            from_transfer=transfer,
            balance_proof=transfer.balance_proof,
            sender=transfer.balance_proof.sender,
        )

    @rule(
        target=init_mediators,
        initiator_address=partners,
        target_address=partners,
        payment_id=payment_id(),  # pylint: disable=no-value-for-parameter
        amount=integers(min_value=1, max_value=100),
        secret=secret(),  # pylint: disable=no-value-for-parameter
    )
    def valid_init_mediator(self, initiator_address, target_address, payment_id,
                            amount, secret):
        # A mediator cannot forward to itself.
        assume(initiator_address != target_address)

        transfer = self._new_mediator_transfer(initiator_address, target_address,
                                               payment_id, amount, secret)
        action = self._action_init_mediator(transfer)
        result = node.state_transition(self.chain_state, action)

        assert event_types_match(result.events, SendProcessed, SendLockedTransfer)

        return action

    @rule(target=secret_requests, previous_action=consumes(init_mediators))
    def valid_receive_secret_reveal(self, previous_action):
        secret = self.secrethash_to_secret[
            previous_action.from_transfer.lock.secrethash]
        sender = previous_action.from_transfer.target
        recipient = previous_action.from_transfer.initiator

        action = ReceiveSecretReveal(secret=secret, sender=sender)
        result = node.state_transition(self.chain_state, action)

        expiration = previous_action.from_transfer.lock.expiration
        # Three windows: early enough to unlock, late but the lock still
        # exists, or after lock removal (reveal dropped).
        in_time = self.block_number < expiration - DEFAULT_NUMBER_OF_BLOCK_CONFIRMATIONS
        still_waiting = self.block_number < expiration + DEFAULT_WAIT_BEFORE_LOCK_REMOVAL

        if in_time and self.channel_opened(sender) and self.channel_opened(recipient):
            assert event_types_match(result.events, SendSecretReveal,
                                     SendBalanceProof, EventUnlockSuccess)
            self.event("Unlock successful.")
            self.waiting_for_unlock[secret] = recipient
        elif still_waiting and self.channel_opened(recipient):
            assert event_types_match(result.events, SendSecretReveal)
            self.event("Unlock failed, secret revealed too late.")
        else:
            assert not result.events
            self.event("ReceiveSecretRevealed after removal of lock - dropped.")
        return action

    @rule(previous_action=secret_requests)
    def replay_receive_secret_reveal(self, previous_action):
        # Re-applying the same reveal must be a no-op.
        result = node.state_transition(self.chain_state, previous_action)
        assert not result.events

    # pylint: disable=no-value-for-parameter
    @rule(previous_action=secret_requests, invalid_sender=address())
    # pylint: enable=no-value-for-parameter
    def replay_receive_secret_reveal_scrambled_sender(self, previous_action,
                                                      invalid_sender):
        # Correct secret but wrong sender must be ignored.
        action = ReceiveSecretReveal(previous_action.secret, invalid_sender)
        result = node.state_transition(self.chain_state, action)
        assert not result.events

    # pylint: disable=no-value-for-parameter
    @rule(previous_action=init_mediators, secret=secret())
    # pylint: enable=no-value-for-parameter
    def wrong_secret_receive_secret_reveal(self, previous_action, secret):
        # A reveal of an unrelated secret must be ignored.
        sender = previous_action.from_transfer.target
        action = ReceiveSecretReveal(secret, sender)
        result = node.state_transition(self.chain_state, action)
        assert not result.events

    # pylint: disable=no-value-for-parameter
    @rule(target=secret_requests,
          previous_action=consumes(init_mediators),
          invalid_sender=address())
    # pylint: enable=no-value-for-parameter
    def wrong_address_receive_secret_reveal(self, previous_action, invalid_sender):
        # Wrong sender is dropped; the matching valid reveal is then returned
        # into the bundle for later replay rules.
        secret = self.secrethash_to_secret[
            previous_action.from_transfer.lock.secrethash]

        invalid_action = ReceiveSecretReveal(secret, invalid_sender)
        result = node.state_transition(self.chain_state, invalid_action)
        assert not result.events

        valid_sender = previous_action.from_transfer.target
        valid_action = ReceiveSecretReveal(secret, valid_sender)
        return valid_action
def test_rule_deprecation_targets_and_target():
    """Supplying both `targets` and `target` to `rule` at once hits the
    deprecated combined-arguments code path."""
    bundle_k = Bundle("k")
    bundle_v = Bundle("v")
    rule(targets=(bundle_k,), target=bundle_v)
class PVectorBuilder(RuleBasedStateMachine):
    """
    Build a list and matching pvector step-by-step.

    In each step in the state machine we do same operation on a list and
    on a pvector, and then when we're done we compare the two.
    """
    sequences = Bundle("sequences")  # entries: (list, pvector) pairs

    @rule(target=sequences, start=PVectorAndLists)
    def initial_value(self, start):
        """
        Some initial values generated by a hypothesis strategy.
        """
        return start

    @rule(target=sequences, former=sequences)
    @verify_inputs_unmodified
    def append(self, former):
        """
        Append an item to the pair of sequences.
        """
        l, pv = former
        obj = TestObject()
        # Copy the list first: the pvector side is persistent, so the list
        # side must not mutate the bundled value either.
        l2 = l[:]
        l2.append(obj)
        return l2, pv.append(obj)

    @rule(target=sequences, start=sequences, end=sequences)
    @verify_inputs_unmodified
    def extend(self, start, end):
        """
        Extend a pair of sequences with another pair of sequences.
        """
        l, pv = start
        l2, pv2 = end
        # compare() has O(N**2) behavior, so don't want too-large lists:
        assume(len(l) + len(l2) < 50)
        l3 = l[:]
        l3.extend(l2)
        return l3, pv.extend(pv2)

    @rule(target=sequences, former=sequences, choice=st.choices())
    @verify_inputs_unmodified
    def remove(self, former, choice):
        """
        Remove an item from the sequences.
        """
        l, pv = former
        assume(l)  # need at least one element to delete
        l2 = l[:]
        i = choice(range(len(l)))
        del l2[i]
        return l2, pv.delete(i)

    @rule(target=sequences, former=sequences, choice=st.choices())
    @verify_inputs_unmodified
    def set(self, former, choice):
        """
        Overwrite an item in the sequence.
        """
        l, pv = former
        assume(l)
        l2 = l[:]
        i = choice(range(len(l)))
        obj = TestObject()
        l2[i] = obj
        return l2, pv.set(i, obj)

    @rule(target=sequences, former=sequences, choice=st.choices())
    @verify_inputs_unmodified
    def transform_set(self, former, choice):
        """
        Transform the sequence by setting value.
        """
        l, pv = former
        assume(l)
        l2 = l[:]
        i = choice(range(len(l)))
        obj = TestObject()
        l2[i] = obj
        # same effect as `set`, but via the transform API
        return l2, pv.transform([i], obj)

    @rule(target=sequences, former=sequences, choice=st.choices())
    @verify_inputs_unmodified
    def transform_discard(self, former, choice):
        """
        Transform the sequence by discarding a value.
        """
        l, pv = former
        assume(l)
        l2 = l[:]
        i = choice(range(len(l)))
        del l2[i]
        return l2, pv.transform([i], discard)

    @rule(target=sequences, former=sequences, choice=st.choices())
    @verify_inputs_unmodified
    def subset(self, former, choice):
        """
        A subset of the previous sequence.
        """
        l, pv = former
        assume(l)
        i = choice(range(len(l)))
        j = choice(range(len(l)))
        # i may exceed j, giving an empty slice — both sides agree on that.
        return l[i:j], pv[i:j]

    @rule(pair=sequences)
    @verify_inputs_unmodified
    def compare(self, pair):
        """
        The list and pvector must match.
        """
        l, pv = pair
        # compare() has O(N**2) behavior, so don't want too-large lists:
        assume(len(l) < 50)
        assert_equal(l, pv)
def test_rule_deprecation_bundle_by_name():
    """Referring to a bundle by its string name (rather than the Bundle
    object) is the deprecated spelling exercised here."""
    bundle_name = "k"
    Bundle(bundle_name)
    rule(target=bundle_name)
class NonTerminalMachine(RuleBasedStateMachine):
    # NOTE(review): the rule keyword is `value` while the method parameter is
    # `hi` — presumably a deliberate mismatch, as this machine reads like a
    # fixture for exercising invalid rule definitions; confirm before "fixing".
    @rule(value=Bundle(u'hi'))
    def bye(self, hi):
        pass
def test_rule_non_bundle_target_oneof():
    """`rule(target=...)` must reject a one_of/`|` combination of bundles
    with a helpful InvalidArgument message."""
    bundle_k = Bundle("k")
    bundle_v = Bundle("v")
    pattern = r".+ `one_of(a, b)` or `a | b` .+"
    combined = bundle_k | bundle_v
    with pytest.raises(InvalidArgument, match=pattern):
        rule(target=combined)
# A rule with no targets and no arguments, attached dynamically.
DynamicMachine.define_rule(targets=(), function=lambda self: 1, arguments={})


class IntAdder(RuleBasedStateMachine):
    # Rules are attached dynamically below via define_rule.
    pass


# Identity rule: feeds drawn integers into the 'ints' bundle.
IntAdder.define_rule(targets=(u'ints', ),
                     function=lambda self, x: x,
                     arguments={u'x': integers()})
# Rule mixing a fresh integer with a previously-produced bundle value.
IntAdder.define_rule(targets=(u'ints', ),
                     function=lambda self, x, y: x,
                     arguments={
                         u'x': integers(),
                         u'y': Bundle(u'ints'),
                     })


class ChoosingMachine(GenericStateMachine):
    # Minimal GenericStateMachine exercising the `choices()` strategy.
    def steps(self):
        return choices()

    def execute_step(self, choices):
        choices([1, 2, 3])


# Settings context: TestCase classes created inside pick up max_examples=10.
with Settings(max_examples=10):
    TestChoosingMachine = ChoosingMachine.TestCase
    TestGoodSets = GoodSet.TestCase
    TestGivenLike = GivenLikeStateMachine.TestCase
class ListSM(RuleBasedStateMachine):
    """Model-check a C doubly-linked list (via cffi `lib`) against two Python
    deques: `_model` mirrors the node handles, `_model_contents` the values."""

    def __init__(self):
        self._var = ffi.new('struct list*')
        self._model = deque()           # mirrors node order
        self._model_contents = deque()  # mirrors stored values
        super().__init__()

    # head/tail node accessors backed by the C struct
    first = alias_property()
    last = alias_property()

    def teardown(self):
        # free all C-side nodes
        lib.list_clear(self._var)

    class Iterator(typing.Iterator[None]):
        # Cursor over the C list; `cur`/`prev` proxy the C iterator struct.
        cur = alias_property()
        prev = alias_property()

    def _make_iter(self, reverse: bool = False) -> ListSM.Iterator:
        """Build an iterator whose `cur` exposes the current C node; each
        `next()` yields None and advances the C-side cursor."""
        var = ffi.new('struct list_it*')
        var.cur = self.last if reverse else self.first

        def it():
            while var.cur != ffi.NULL:
                yield  # yield None because we're mutable
                lib.list_step(var)
        it = it()  # replace the generator function with the running generator

        class _It(ListSM.Iterator):
            _var = var

            def __next__(self) -> None:
                return next(it)
        return _It()

    def __iter__(self) -> ListSM.Iterator:
        return self._make_iter()

    def __reversed__(self) -> ListSM.Iterator:
        return self._make_iter(reverse=True)

    nodes = Bundle('Nodes')

    @rule(new_value=elements(), target=nodes)
    def insert_front(self, new_value):
        self._model_contents.appendleft(new_value)
        lib.list_insert_front(self._var, new_value)
        new_node = self.first
        assert new_node.value == new_value
        self._model.appendleft(new_node)
        return new_node

    @rule(new_value=elements(), target=nodes)
    def insert_back(self, new_value):
        self._model_contents.append(new_value)
        lib.list_insert_back(self._var, new_value)
        new_node = self.last
        assert new_node.value == new_value
        self._model.append(new_node)
        return new_node

    @rule(nodes=strategies.frozensets(consumes(nodes)),
          reverse=strategies.booleans())
    def remove_thru_iter(self, nodes, reverse):
        """Remove a consumed set of nodes while iterating (optionally in
        reverse), then drop the same entries from both model deques."""
        it = reversed(self) if reverse else iter(self)
        for _ in it:
            if it.cur in nodes:
                lib.list_remove(self._var, it._var)
        for n in nodes:
            i = self._model.index(n)
            del self._model_contents[i]
            del self._model[i]

    @invariant()
    def nodes_as_model(self):
        # C-side node order must match the model deque exactly.
        it = iter(self)
        nodes = [it.cur for _ in it]
        assert nodes == list(self._model)

    @invariant()
    def contents_as_model(self):
        # C-side stored values must match the contents deque exactly.
        it = iter(self)
        contents = [it.cur.value for _ in it]
        assert contents == list(self._model_contents)