Example #1
class MachineWithConsumingRule(RuleBasedStateMachine):
    b1 = Bundle("b1")
    b2 = Bundle("b2")

    def __init__(self):
        self.created_counter = 0
        self.consumed_counter = 0
        super(MachineWithConsumingRule, self).__init__()

    @invariant()
    def bundle_length(self):
        assert len(
            self.bundle("b1")) == self.created_counter - self.consumed_counter

    @rule(target=b1)
    def populate_b1(self):
        self.created_counter += 1
        return self.created_counter

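    # consumes(b1) draws a value from the b1 bundle and removes it, so the
    # same value cannot be handed to another rule later.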
    @rule(target=b2, consumed=consumes(b1))
    def depopulate_b1(self, consumed):
        self.consumed_counter += 1
        return consumed

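    # Wrapping consumes(b1) in lists() removes several bundle entries in a
    # single rule invocation.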
    @rule(consumed=lists(consumes(b1)))
    def depopulate_b1_multiple(self, consumed):
        self.consumed_counter += len(consumed)

    @rule(value1=b1, value2=b2)
    def check(self, value1, value2):
        assert value1 != value2
Example #2
class OnChainMixin:

    @rule(number=integers(min_value=1, max_value=50))
    def new_blocks(self, number):
        events = list()

        for _ in range(number):
            block_state_change = Block(
                block_number=self.block_number + 1,
                gas_limit=1,
                block_hash=factories.make_keccak_hash(),
            )
            result = node.state_transition(self.chain_state, block_state_change)
            events.extend(result.events)

            self.block_number += 1

    @rule(target=partners)
    def open_channel(self):
        return self.new_channel_with_transaction()

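    # Consuming the partner address ensures a channel is never settled twice.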
    @rule(partner=consumes(partners))
    def settle_channel(self, partner):
        channel = self.address_to_channel[partner]

        channel_settled_state_change = ContractReceiveChannelSettled(
            transaction_hash=factories.make_transaction_hash(),
            token_network_identifier=channel.token_network_identifier,
            channel_identifier=channel.identifier,
            block_number=self.block_number + 1,
        )

        node.state_transition(self.chain_state, channel_settled_state_change)
Example #3
class StatefulDictStateMachine(RuleBasedStateMachine):
    def __init__(self):
        super().__init__()
        self.tree = TreeDict()
        self.in_dict = {}

    inserted_keys = Bundle("inserted")
    deleted_keys = Bundle("deleted_keys")

    @rule(target=inserted_keys, key=some.integers(), v=some.text())
    def insert(self, key, v):
        self.tree[key] = v
        self.in_dict[key] = v
        return key

    @rule(key=inserted_keys)
    def search(self, key):
        # The key may already have been deleted if it was a duplicate insert.
        assume(key in self.in_dict)
        assert self.tree[key] == self.in_dict[key]

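    # Deleting consumes the key from the bundle, so already-deleted keys are
    # not drawn again (duplicate inserts can still leave a stale entry).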
    @rule(key=consumes(inserted_keys))
    def delete(self, key):
        assume(key in self.in_dict)
        del self.tree[key]
        del self.in_dict[key]

    @rule(key=some.integers())
    def search_non_existing(self, key):
        assume(key not in self.in_dict)
        with pytest.raises(KeyError):  # type: ignore
            self.tree[key]
Example #4
class AutoFSMTest(RuleBasedStateMachine):
    def __init__(self):
        super().__init__()
        self.auto = TransmissionSystem()

    rpms = Bundle('rpms')
    rpm_sets = Bundle('rpm_sets')

    @rule(target=rpms, rpm=st.integers(min_value=0))
    def add_rpm(self, rpm):
        return rpm

    @rule(target=rpm_sets, rpms=st.lists(st.integers(min_value=0)))
    def add_rpms(self, rpms):
        return rpms

    ## These methods exercise the step and run methods of
    ## TransmissionSystem, as possible intervening actions between
    ## test assertions
    @rule(rpm=consumes(rpms))
    def step(self, rpm):
        self.auto.step(rpm)

    @rule(rpms=consumes(rpm_sets))
    def run(self, rpms):
        self.auto.run(rpms)

    # These are the test methods that assert facts about the state machine
    @invariant()
    def state_is_always_a_gear_state(self):
        assert isinstance(self.auto.state, GearState)

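    # The precondition keeps this rule from being chosen unless the gearbox is
    # currently in Neutral.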
    @precondition(lambda self: isinstance(self.auto.state, Neutral))
    @rule(rpm=consumes(rpms))
    def step_from_neutral_must_be_neutral_or_first(self, rpm):
        """Given Neutral state, then next state must be Neutral or FirstGear"""
        self.auto.step(rpm)
        state = self.auto.state
        assert isinstance(state, Neutral) or isinstance(state, FirstGear)
Example #5
class ManagerTest(RuleBasedStateMachine):
    boxes = Bundle("boxes")

    def __init__(self):
        super().__init__()
        self.manager = box_manager.BoxManager()
        self.current_boxes = set()

    @rule(target=boxes, new_boxes=BOXES_ST)
    def add_boxes(self, new_boxes):
        for b in new_boxes:
            self.manager.register(b)
            self.current_boxes.add(b)
            assert b.manager is not None
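        # multiple() adds each registered box to the bundle as a separate entry.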
        return multiple(*new_boxes)

    @rule(to_remove=consumes(boxes))
    def remove_box(self, to_remove):
        self.manager.remove(to_remove)
        assert to_remove in self.current_boxes
        self.current_boxes.remove(to_remove)
        assert to_remove.manager is None
        if to_remove.stationary:
            assert not self.manager.stationary_cache_valid

    @rule(
        box=boxes,
        new_x=st.floats(min_value=-100.0, max_value=100.0),
        new_y=st.floats(min_value=-100.0, max_value=100.0),
    )
    def move_box(self, box, new_x, new_y):
        box.move(new_x, new_y)
        assert box.manager is not None
        if box.stationary:
            assert not self.manager.stationary_cache_valid

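    # The invariant is checked after every step, comparing the manager's
    # collision results against the brute-force model.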
    @invariant()
    def check_box_collisions(self):
        all_boxes = list(self.current_boxes)
        good = list(base_algorithms.check_deduplicated(all_boxes))
        # We run this one twice, so that the first one has a chance to cache the stationary boxes.
        # Then we check all 3 of them.
        unknown_precached = list(self.manager.yield_collisions())
        unknown_cached = list(self.manager.yield_collisions())
        good_s = make_deterministic_set(good)
        unknown_precached_s = make_deterministic_set(unknown_precached)
        unknown_cached_s = make_deterministic_set(unknown_cached)
        assert len(good_s.intersection(unknown_precached_s)) == len(good_s)
        assert len(good_s.intersection(unknown_cached_s)) == len(good_s)
Example #6
class GitlabStateful(RuleBasedStateMachine):
    def __init__(self):
        super().__init__()

        self._gitlab = GitlabAPI()
        self._gitlab.prepare()

    created_users = Bundle("users")

    @rule(target=created_users, user=users())
    def create_new_user(self, user: User):
        """Create new user from generated user.

        TODO: Should this check to not accidentally create the same user as before?
        """

        # Perform operation on real system
        self._gitlab.create_user(user)

        # Returning the value stores it into the bundle
        return user

    @rule(user=created_users)
    @expect_exception(GitlabException)
    def create_existing_user(self, user: User):
        """Test creating an existing user, should raise an exception.
        User is drawn from the `created_users` bundle, guaranteeing it has been created.
        """
        self._gitlab.create_user(user)

    @rule(user=created_users)
    def get_existing_user(self, user: User):
        """Test fetching an existing user, as post-condition both model and states should agree.
        """
        fetched_user = self._gitlab.fetch_user(user.uid)
        assert fetched_user == user

    @rule(user=users())
    def get_non_existing_user(self, user: User):
        """Test fetching an non-existing user, should return None.
        """
        fetched_user = self._gitlab.fetch_user(user.uid)
        assert fetched_user is None

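    # Consuming the user removes it from the bundle, so rules that require an
    # existing user never receive one that has already been deleted.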
    @rule(user=consumes(created_users))
    def delete_user(self, user: User):
        """Test deleting an existing user. Consumes user from the created users bundle.
        """
        self._gitlab.delete_user(user.uid)
Example #7
class StatefulDictStateMachine(RuleBasedStateMachine):
    def __init__(self):
        super().__init__()
        self.sorted_dict = SortedDict()
        self.state = {}

    inserted_keys = Bundle("inserted")
    deleted_keys = Bundle("deleted_keys")

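    # event() labels each action so Hypothesis' statistics report shows how
    # often inserts, searches and deletes actually occur.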
    @rule(target=inserted_keys, key=some.integers(), value=some.text())
    def insert(self, key, value):
        event("Inserting key")
        self.sorted_dict[key] = value
        self.state[key] = value

        return key

    @rule(key=inserted_keys)
    def search(self, key):
        # A key inserted before may have already been
        # deleted if it was a duplicate, so searching it
        # may not succeed. Check the key exists in
        # the model dictionary.
        assume(key in self.state)
        event("Searching existing key")
        assert self.sorted_dict[key] == self.state[key]

    @rule(key=consumes(inserted_keys))
    def delete(self, key):
        assume(key in self.state)
        event("Deleting key")
        del self.sorted_dict[key]
        del self.state[key]

    @rule(key=some.integers())
    def search_non_existing(self, key):
        assume(key not in self.state)
        event("Searching non-existing key")
        with pytest.raises(KeyError):  # type: ignore
            self.sorted_dict[key]

    @invariant()
    def keys_sorted(self):
        keys = self.sorted_dict.keys()
        assert keys == sorted(keys)
Example #8
class OnChainMixin:

    block_number: BlockNumber

    @rule(number=integers(min_value=1, max_value=50))
    def new_blocks(self, number):
        for _ in range(number):
            block_state_change = Block(
                block_number=BlockNumber(self.block_number + 1),
                gas_limit=BlockGasLimit(1),
                block_hash=make_block_hash(),
            )
            for client in self.address_to_client.values():
                events = list()
                result = node.state_transition(client.chain_state,
                                               block_state_change)
                events.extend(result.events)
            # TODO assert on events

            self.block_number += 1

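    # A rule may draw from and target the same bundle: the reference pair is
    # kept, and the newly opened channel's pair is added alongside it.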
    @rule(reference=address_pairs, target=address_pairs)
    def open_channel(self, reference):
        return self.new_channel_with_transaction(reference.our_address)

    @rule(address_pair=consumes(address_pairs))
    def settle_channel(self, address_pair):
        client = self.address_to_client[address_pair.our_address]
        channel = client.address_to_channel[address_pair.partner_address]

        channel_settled_state_change = ContractReceiveChannelSettled(
            transaction_hash=factories.make_transaction_hash(),
            canonical_identifier=factories.make_canonical_identifier(
                chain_identifier=channel.chain_id,
                token_network_address=channel.token_network_address,
                channel_identifier=channel.identifier,
            ),
            block_number=self.block_number + 1,
            block_hash=factories.make_block_hash(),
            our_onchain_locksroot=LOCKSROOT_OF_NO_LOCKS,
            partner_onchain_locksroot=LOCKSROOT_OF_NO_LOCKS,
        )

        node.state_transition(client.chain_state, channel_settled_state_change)
Example #9
    class StrategyStateMachine(RuleBasedStateMachine):
        def __init__(self):
            super(StrategyStateMachine, self).__init__()
            self.model = set()
            self.agenda = Agenda()
            self.strategy = strategy()
            self.fss = set()

        activations = Bundle("activations")

        @rule(target=activations,
              r=st.integers(min_value=0),
              fs=st.sets(st.integers(min_value=0), min_size=1))
        def declare(self, r, fs):
            assume((r, frozenset(fs)) not in self.fss)
            self.fss.add((r, frozenset(fs)))

            fs = [Fact(i, __factid__=i) for i in fs]
            act = Activation(Rule(Fact(r)), facts=tuple(fs))

            # Update agenda
            self.strategy.update_agenda(self.agenda, [act], [])

            # Update model
            self.model |= set([act])

            return act

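        # Retracting consumes the activation so it cannot be retracted twice,
        # keeping the agenda and the model in sync.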
        @rule(act=consumes(activations))
        def retract(self, act):
            # Update agenda
            self.strategy.update_agenda(self.agenda, [], [act])

            # Update model
            self.model -= set([act])

        @invariant()
        def values_agree(self):
            assert set(self.agenda.activations) == self.model
Example #10
class PtrMixSM(stateful.RuleBasedStateMachine):
    def __init__(self):
        super().__init__()
        self._var = None
        self._contents = {}

    mixed_pointers = stateful.Bundle("mixed pointers")

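    # initialize() rules run once at the start of each test case, before any
    # ordinary rules are executed.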
    @stateful.initialize(target=mixed_pointers, a=pointers(), b=pointers())
    def mixptr(self, a, b):
        self._var = lib.mixptr(a, b)
        self._contents = {a: b, b: a}
        return stateful.multiple(a, b)

    @stateful.invariant()
    def equation(self):
        # nothing to check if called before initialization.
        if self._contents:
            # make it work if a == b and thus _contents has 1 entry
            contents = list(self._contents)
            a, b = contents[0], contents[-1]
            assert self._var.content == lib.mixptr(a, b).content

    @stateful.invariant()
    def unmixptr(self):
        for ptr in self._contents:
            assert lib.unmixptr(self._var, ptr) == self._contents[ptr]

    @stateful.rule(target=mixed_pointers,
                   a=stateful.consumes(mixed_pointers),
                   b=pointers())
    def remixptr(self, a, b):
        lib.remixptr(ffi.addressof(self._var), a, b)
        a = self._contents[a]
        self._contents = {a: b, b: a}
        return b
Example #11
def test_consumes_typecheck():
    with pytest.raises(TypeError):
        consumes(integers())
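# consumes() only accepts a Bundle; wrapping an ordinary strategy such as
# integers() is rejected with a TypeError.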
Example #12
class MediatorMixin:
    address_to_privkey: Dict[Address, PrivateKey]
    address_to_client: Dict[Address, Client]
    block_number: BlockNumber
    token_id: TokenAddress

    def __init__(self):
        super().__init__()
        self.partner_to_balance_proof_data: Dict[Address,
                                                 BalanceProofData] = dict()
        self.secrethash_to_secret: Dict[SecretHash, Secret] = dict()
        self.waiting_for_unlock: Dict[Secret, Address] = dict()
        self.initial_number_of_channels = 2

    def _get_balance_proof_data(self, partner, client_address):
        if partner not in self.partner_to_balance_proof_data:
            client = self.address_to_client[client_address]
            partner_channel = client.address_to_channel[partner]
            self.partner_to_balance_proof_data[partner] = BalanceProofData(
                canonical_identifier=partner_channel.canonical_identifier)
        return self.partner_to_balance_proof_data[partner]

    def _update_balance_proof_data(self, partner, amount, expiration, secret,
                                   our_address):
        expected = self._get_balance_proof_data(partner, our_address)
        lock = HashTimeLockState(amount=amount,
                                 expiration=expiration,
                                 secrethash=sha256_secrethash(secret))
        expected.update(amount, lock)
        return expected

    init_mediators = Bundle("init_mediators")
    secret_requests = Bundle("secret_requests")
    unlocks = Bundle("unlocks")

    def _new_mediator_transfer(self, initiator_address, target_address,
                               payment_id, amount, secret,
                               our_address) -> LockedTransferSignedState:
        initiator_pkey = self.address_to_privkey[initiator_address]
        balance_proof_data = self._update_balance_proof_data(
            initiator_address, amount, self.block_number + 10, secret,
            our_address)
        self.secrethash_to_secret[sha256_secrethash(secret)] = secret

        return factories.create(
            factories.LockedTransferSignedStateProperties(  # type: ignore
                **balance_proof_data.properties.__dict__,
                amount=amount,
                expiration=BlockExpiration(self.block_number + 10),
                payment_identifier=payment_id,
                secret=secret,
                initiator=initiator_address,
                target=target_address,
                token=self.token_id,
                sender=initiator_address,
                recipient=our_address,
                pkey=initiator_pkey,
                message_identifier=MessageID(1),
            ))

    def _action_init_mediator(self, transfer: LockedTransferSignedState,
                              client_address) -> WithOurAddress:
        client = self.address_to_client[client_address]
        initiator_channel = client.address_to_channel[Address(
            transfer.initiator)]
        target_channel = client.address_to_channel[Address(transfer.target)]
        assert isinstance(target_channel, NettingChannelState)

        action = ActionInitMediator(
            route_states=[factories.make_route_from_channel(target_channel)],
            from_hop=factories.make_hop_to_channel(initiator_channel),
            from_transfer=transfer,
            balance_proof=transfer.balance_proof,
            sender=transfer.balance_proof.sender,
        )
        return WithOurAddress(our_address=client_address, data=action)

    def _unwrap(self, with_our_address: WithOurAddress):
        our_address = with_our_address.our_address
        data = with_our_address.data
        client = self.address_to_client[our_address]
        return data, client, our_address

    @rule(
        target=init_mediators,
        from_channel=address_pairs,
        to_channel=address_pairs,
        payment_id=payment_id(),  # pylint: disable=no-value-for-parameter
        amount=integers(min_value=1, max_value=100),
        secret=secret(),  # pylint: disable=no-value-for-parameter
    )
    def valid_init_mediator(self, from_channel, to_channel, payment_id, amount,
                            secret):
        our_address = from_channel.our_address
        assume(to_channel.our_address ==
               our_address)  # FIXME this will be too slow
        client = self.address_to_client[our_address]

        from_partner = from_channel.partner_address
        to_partner = to_channel.partner_address
        assume(from_partner != to_partner)

        transfer = self._new_mediator_transfer(from_partner, to_partner,
                                               payment_id, amount, secret,
                                               our_address)
        client_data = self._action_init_mediator(transfer, our_address)
        result = node.state_transition(client.chain_state, client_data.data)

        assert event_types_match(result.events, SendProcessed,
                                 SendLockedTransfer)

        return client_data

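    # consumes(init_mediators) removes the drawn init action from the bundle,
    # so the same valid mediator init is not revealed twice by this rule.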
    @rule(target=secret_requests,
          previous_action_with_address=consumes(init_mediators))
    def valid_receive_secret_reveal(self, previous_action_with_address):
        previous_action, client, our_address = self._unwrap(
            previous_action_with_address)

        secret = self.secrethash_to_secret[
            previous_action.from_transfer.lock.secrethash]
        sender = previous_action.from_transfer.target
        recipient = previous_action.from_transfer.initiator

        action = ReceiveSecretReveal(secret=secret, sender=sender)
        result = node.state_transition(client.chain_state, action)

        expiration = previous_action.from_transfer.lock.expiration
        in_time = self.block_number < expiration - DEFAULT_NUMBER_OF_BLOCK_CONFIRMATIONS
        still_waiting = self.block_number < expiration + DEFAULT_WAIT_BEFORE_LOCK_REMOVAL

        if (in_time and self.channel_opened(sender, our_address)
                and self.channel_opened(recipient, our_address)):
            assert event_types_match(result.events, SendSecretReveal,
                                     SendUnlock, EventUnlockSuccess)
            self.event("Unlock successful.")
            self.waiting_for_unlock[secret] = recipient
        elif still_waiting and self.channel_opened(recipient, our_address):
            assert event_types_match(result.events, SendSecretReveal)
            self.event("Unlock failed, secret revealed too late.")
        else:
            assert not result.events
            self.event(
                "ReceiveSecretRevealed after removal of lock - dropped.")
        return WithOurAddress(our_address=our_address, data=action)

    @rule(previous_action_with_address=secret_requests)
    def replay_receive_secret_reveal(self, previous_action_with_address):
        previous_action, client, _ = self._unwrap(previous_action_with_address)
        result = node.state_transition(client.chain_state, previous_action)
        assert not result.events

    # pylint: disable=no-value-for-parameter
    @rule(previous_action_with_address=secret_requests,
          invalid_sender=address())
    # pylint: enable=no-value-for-parameter
    def replay_receive_secret_reveal_scrambled_sender(
            self, previous_action_with_address, invalid_sender):
        previous_action, client, _ = self._unwrap(previous_action_with_address)
        action = ReceiveSecretReveal(previous_action.secret, invalid_sender)
        result = node.state_transition(client.chain_state, action)
        assert not result.events

    # pylint: disable=no-value-for-parameter
    @rule(previous_action_with_address=init_mediators, secret=secret())
    # pylint: enable=no-value-for-parameter
    def wrong_secret_receive_secret_reveal(self, previous_action_with_address,
                                           secret):
        previous_action, client, _ = self._unwrap(previous_action_with_address)
        sender = previous_action.from_transfer.target
        action = ReceiveSecretReveal(secret, sender)
        result = node.state_transition(client.chain_state, action)
        assert not result.events

    # pylint: disable=no-value-for-parameter
    @rule(
        target=secret_requests,
        previous_action_with_address=consumes(init_mediators),
        invalid_sender=address(),
    )
    # pylint: enable=no-value-for-parameter
    def wrong_address_receive_secret_reveal(self, previous_action_with_address,
                                            invalid_sender):
        previous_action, client, our_address = self._unwrap(
            previous_action_with_address)
        secret = self.secrethash_to_secret[
            previous_action.from_transfer.lock.secrethash]
        invalid_action = ReceiveSecretReveal(secret, invalid_sender)
        result = node.state_transition(client.chain_state, invalid_action)
        assert not result.events

        valid_sender = previous_action.from_transfer.target
        valid_action = ReceiveSecretReveal(secret, valid_sender)
        return WithOurAddress(our_address=our_address, data=valid_action)
Example #13
class AuthStateMachine(RuleBasedStateMachine):
    """
    State machine for auth flows

    How to understand this code:

    This code exercises our social auth APIs, which together form a graph of nodes and edges that the user traverses.
    You can understand the bundles defined below to be the nodes and the methods of this class to be the edges.

    If you add a new state to the auth flows, create a new bundle to represent that state and define
    methods for the transitions into and (optionally) out of that state.
    """

    # pylint: disable=too-many-instance-attributes

    ConfirmationSentAuthStates = Bundle("confirmation-sent")
    ConfirmationRedeemedAuthStates = Bundle("confirmation-redeemed")
    RegisterExtraDetailsAuthStates = Bundle("register-details-extra")

    LoginPasswordAuthStates = Bundle("login-password")
    LoginPasswordAbandonedAuthStates = Bundle("login-password-abandoned")

    recaptcha_patcher = patch(
        "authentication.views.requests.post",
        return_value=MockResponse(content='{"success": true}',
                                  status_code=status.HTTP_200_OK),
    )
    email_send_patcher = patch("mail.verification_api.send_verification_email",
                               autospec=True)
    courseware_api_patcher = patch(
        "authentication.pipeline.user.courseware_api")
    courseware_tasks_patcher = patch(
        "authentication.pipeline.user.courseware_tasks")

    def __init__(self):
        """Setup the machine"""
        super().__init__()
        # wrap the execution in a django transaction, similar to django's TestCase
        self.atomic = transaction.atomic()
        self.atomic.__enter__()

        # wrap the execution in a patch()
        self.mock_email_send = self.email_send_patcher.start()
        self.mock_courseware_api = self.courseware_api_patcher.start()
        self.mock_courseware_tasks = self.courseware_tasks_patcher.start()

        # django test client
        self.client = Client()

        # shared data
        self.email = fake.email()
        self.user = None
        self.password = "******"

        # track whether we've hit an action that starts a flow or not
        self.flow_started = False

    def teardown(self):
        """Cleanup from a run"""
        # clear the mailbox
        del mail.outbox[:]

        # stop the patches
        self.email_send_patcher.stop()
        self.courseware_api_patcher.stop()
        self.courseware_tasks_patcher.stop()

        # end the transaction with a rollback to cleanup any state
        transaction.set_rollback(True)
        self.atomic.__exit__(None, None, None)

    def create_existing_user(self):
        """Create an existing user"""
        self.user = UserFactory.create(email=self.email)
        self.user.set_password(self.password)
        self.user.save()
        UserSocialAuthFactory.create(user=self.user,
                                     provider=EmailAuth.name,
                                     uid=self.user.email)

    @rule(
        target=ConfirmationSentAuthStates,
        recaptcha_enabled=st.sampled_from([True, False]),
    )
    @precondition(lambda self: not self.flow_started)
    def register_email_not_exists(self, recaptcha_enabled):
        """Register email not exists"""
        self.flow_started = True

        with ExitStack() as stack:
            mock_recaptcha_success = None
            if recaptcha_enabled:
                mock_recaptcha_success = stack.enter_context(
                    self.recaptcha_patcher)
                stack.enter_context(
                    override_settings(**{"RECAPTCHA_SITE_KEY": "fake"}))
            result = assert_api_call(
                self.client,
                "psa-register-email",
                {
                    "flow": SocialAuthState.FLOW_REGISTER,
                    "email": self.email,
                    **({
                        "recaptcha": "fake"
                    } if recaptcha_enabled else {}),
                },
                {
                    "flow": SocialAuthState.FLOW_REGISTER,
                    "partial_token": None,
                    "state": SocialAuthState.STATE_REGISTER_CONFIRM_SENT,
                },
            )
            self.mock_email_send.assert_called_once()
            if mock_recaptcha_success:
                mock_recaptcha_success.assert_called_once()
            return result

    @rule(target=LoginPasswordAuthStates,
          recaptcha_enabled=st.sampled_from([True, False]))
    @precondition(lambda self: not self.flow_started)
    def register_email_exists(self, recaptcha_enabled):
        """Register email exists"""
        self.flow_started = True
        self.create_existing_user()

        with ExitStack() as stack:
            mock_recaptcha_success = None
            if recaptcha_enabled:
                mock_recaptcha_success = stack.enter_context(
                    self.recaptcha_patcher)
                stack.enter_context(
                    override_settings(**{"RECAPTCHA_SITE_KEY": "fake"}))

            result = assert_api_call(
                self.client,
                "psa-register-email",
                {
                    "flow": SocialAuthState.FLOW_REGISTER,
                    "email": self.email,
                    "next": NEXT_URL,
                    **({
                        "recaptcha": "fake"
                    } if recaptcha_enabled else {}),
                },
                {
                    "flow": SocialAuthState.FLOW_REGISTER,
                    "state": SocialAuthState.STATE_LOGIN_PASSWORD,
                    "errors": ["Password is required to login"],
                },
            )
            self.mock_email_send.assert_not_called()
            if mock_recaptcha_success:
                mock_recaptcha_success.assert_called_once()
            return result

    @rule()
    @precondition(lambda self: not self.flow_started)
    def register_email_not_exists_with_recaptcha_invalid(self):
        """Yield a function for this step"""
        self.flow_started = True
        with patch(
                "authentication.views.requests.post",
                return_value=MockResponse(
                    content=
                    '{"success": false, "error-codes": ["bad-request"]}',
                    status_code=status.HTTP_200_OK,
                ),
        ) as mock_recaptcha_failure, override_settings(
                **{"RECAPTCHA_SITE_KEY": "fakse"}):
            assert_api_call(
                self.client,
                "psa-register-email",
                {
                    "flow": SocialAuthState.FLOW_REGISTER,
                    "email": NEW_EMAIL,
                    "recaptcha": "fake",
                },
                {
                    "error-codes": ["bad-request"],
                    "success": False
                },
                expect_status=status.HTTP_400_BAD_REQUEST,
                use_defaults=False,
            )
            mock_recaptcha_failure.assert_called_once()
            self.mock_email_send.assert_not_called()

    @rule()
    @precondition(lambda self: not self.flow_started)
    def login_email_not_exists(self):
        """Login for an email that doesn't exist"""
        self.flow_started = True
        assert_api_call(
            self.client,
            "psa-login-email",
            {
                "flow": SocialAuthState.FLOW_LOGIN,
                "email": self.email
            },
            {
                "field_errors": {
                    "email": "Couldn't find your account"
                },
                "flow": SocialAuthState.FLOW_LOGIN,
                "partial_token": None,
                "state": SocialAuthState.STATE_REGISTER_REQUIRED,
            },
        )
        assert User.objects.filter(email=self.email).exists() is False

    @rule(target=LoginPasswordAuthStates)
    @precondition(lambda self: not self.flow_started)
    def login_email_exists(self):
        """Login with a user that exists"""
        self.flow_started = True
        self.create_existing_user()

        return assert_api_call(
            self.client,
            "psa-login-email",
            {
                "flow": SocialAuthState.FLOW_LOGIN,
                "email": self.user.email,
                "next": NEXT_URL,
            },
            {
                "flow": SocialAuthState.FLOW_LOGIN,
                "state": SocialAuthState.STATE_LOGIN_PASSWORD,
                "extra_data": {
                    "name": self.user.name
                },
            },
        )

    @rule(
        target=LoginPasswordAbandonedAuthStates,
        auth_state=consumes(RegisterExtraDetailsAuthStates),
    )
    @precondition(lambda self: self.flow_started)
    def login_email_abandoned(self, auth_state):  # pylint: disable=unused-argument
        """Login with a user that abandoned the register flow"""
        # NOTE: This works by "consuming" an extra details auth state,
        #       but discarding the state and starting a new login.
        #       It then re-targets the new state into the extra details again.
        auth_state = None  # assign None to ensure no accidental usage here

        return assert_api_call(
            self.client,
            "psa-login-email",
            {
                "flow": SocialAuthState.FLOW_LOGIN,
                "email": self.user.email,
                "next": NEXT_URL,
            },
            {
                "flow": SocialAuthState.FLOW_LOGIN,
                "state": SocialAuthState.STATE_LOGIN_PASSWORD,
                "extra_data": {
                    "name": self.user.name
                },
            },
        )

    @rule(
        target=RegisterExtraDetailsAuthStates,
        auth_state=consumes(LoginPasswordAbandonedAuthStates),
    )
    def login_password_abandoned(self, auth_state):
        """Login with an abandoned registration user"""
        return assert_api_call(
            self.client,
            "psa-login-password",
            {
                "flow": auth_state["flow"],
                "partial_token": auth_state["partial_token"],
                "password": self.password,
            },
            {
                "flow": auth_state["flow"],
                "state": SocialAuthState.STATE_REGISTER_EXTRA_DETAILS,
            },
        )

    @rule(auth_state=consumes(LoginPasswordAuthStates))
    def login_password_valid(self, auth_state):
        """Login with a valid password"""
        assert_api_call(
            self.client,
            "psa-login-password",
            {
                "flow": auth_state["flow"],
                "partial_token": auth_state["partial_token"],
                "password": self.password,
            },
            {
                "flow": auth_state["flow"],
                "redirect_url": NEXT_URL,
                "partial_token": None,
                "state": SocialAuthState.STATE_SUCCESS,
            },
            expect_authenticated=True,
        )

    @rule(target=LoginPasswordAuthStates,
          auth_state=consumes(LoginPasswordAuthStates))
    def login_password_invalid(self, auth_state):
        """Login with an invalid password"""
        return assert_api_call(
            self.client,
            "psa-login-password",
            {
                "flow": auth_state["flow"],
                "partial_token": auth_state["partial_token"],
                "password": "******",
            },
            {
                "field_errors": {
                    "password":
                    "******"
                },
                "flow": auth_state["flow"],
                "state": SocialAuthState.STATE_ERROR,
            },
        )

    @rule(
        auth_state=consumes(LoginPasswordAuthStates),
        verify_exports=st.sampled_from([True, False]),
    )
    def login_password_user_inactive(self, auth_state, verify_exports):
        """Login for an inactive user"""
        self.user.is_active = False
        self.user.save()

        cm = export_check_response("100_success") if verify_exports else noop()

        with cm:
            assert_api_call(
                self.client,
                "psa-login-password",
                {
                    "flow": auth_state["flow"],
                    "partial_token": auth_state["partial_token"],
                    "password": self.password,
                },
                {
                    "flow": auth_state["flow"],
                    "redirect_url": NEXT_URL,
                    "partial_token": None,
                    "state": SocialAuthState.STATE_SUCCESS,
                },
                expect_authenticated=True,
            )

    @rule(auth_state=consumes(LoginPasswordAuthStates))
    def login_password_exports_temporary_error(self, auth_state):
        """Login for a user who hasn't been OFAC verified yet"""
        with override_settings(**get_cybersource_test_settings()), patch(
                "authentication.pipeline.compliance.api.verify_user_with_exports",
                side_effect=Exception(
                    "register_details_export_temporary_error"),
        ):
            assert_api_call(
                self.client,
                "psa-login-password",
                {
                    "flow": auth_state["flow"],
                    "partial_token": auth_state["partial_token"],
                    "password": self.password,
                },
                {
                    "flow":
                    auth_state["flow"],
                    "partial_token":
                    None,
                    "state":
                    SocialAuthState.STATE_ERROR_TEMPORARY,
                    "errors": [
                        "Unable to register at this time, please try again later"
                    ],
                },
            )

    @rule(
        target=ConfirmationRedeemedAuthStates,
        auth_state=consumes(ConfirmationSentAuthStates),
    )
    def redeem_confirmation_code(self, auth_state):
        """Redeem a registration confirmation code"""
        _, _, code, partial_token = self.mock_email_send.call_args[0]
        return assert_api_call(
            self.client,
            "psa-register-confirm",
            {
                "flow": auth_state["flow"],
                "verification_code": code.code,
                "partial_token": partial_token,
            },
            {
                "flow": auth_state["flow"],
                "state": SocialAuthState.STATE_REGISTER_DETAILS,
            },
        )

    @rule(auth_state=consumes(ConfirmationRedeemedAuthStates))
    def redeem_confirmation_code_twice(self, auth_state):
        """Redeeming a code twice should fail"""
        _, _, code, partial_token = self.mock_email_send.call_args[0]
        assert_api_call(
            self.client,
            "psa-register-confirm",
            {
                "flow": auth_state["flow"],
                "verification_code": code.code,
                "partial_token": partial_token,
            },
            {
                "errors": [],
                "flow": auth_state["flow"],
                "redirect_url": None,
                "partial_token": None,
                "state": SocialAuthState.STATE_INVALID_LINK,
            },
        )

    @rule(auth_state=consumes(ConfirmationRedeemedAuthStates))
    def redeem_confirmation_code_twice_existing_user(self, auth_state):
        """Redeeming a code twice with an existing user should fail with existing account state"""
        _, _, code, partial_token = self.mock_email_send.call_args[0]
        self.create_existing_user()
        assert_api_call(
            self.client,
            "psa-register-confirm",
            {
                "flow": auth_state["flow"],
                "verification_code": code.code,
                "partial_token": partial_token,
            },
            {
                "errors": [],
                "flow": auth_state["flow"],
                "redirect_url": None,
                "partial_token": None,
                "state": SocialAuthState.STATE_EXISTING_ACCOUNT,
            },
        )

    @rule(
        target=RegisterExtraDetailsAuthStates,
        auth_state=consumes(ConfirmationRedeemedAuthStates),
    )
    def register_details(self, auth_state):
        """Complete the register confirmation details page"""
        result = assert_api_call(
            self.client,
            "psa-register-details",
            {
                "flow": auth_state["flow"],
                "partial_token": auth_state["partial_token"],
                "password": self.password,
                "name": "Sally Smith",
                "legal_address": {
                    "first_name": "Sally",
                    "last_name": "Smith",
                    "street_address": ["Main Street"],
                    "country": "US",
                    "state_or_territory": "US-CO",
                    "city": "Boulder",
                    "postal_code": "02183",
                },
            },
            {
                "flow": auth_state["flow"],
                "state": SocialAuthState.STATE_REGISTER_EXTRA_DETAILS,
            },
        )
        self.user = User.objects.get(email=self.email)
        return result

    @rule(
        target=RegisterExtraDetailsAuthStates,
        auth_state=consumes(ConfirmationRedeemedAuthStates),
    )
    def register_details_export_success(self, auth_state):
        """Complete the register confirmation details page with exports enabled"""
        with export_check_response("100_success"):
            result = assert_api_call(
                self.client,
                "psa-register-details",
                {
                    "flow": auth_state["flow"],
                    "partial_token": auth_state["partial_token"],
                    "password": self.password,
                    "name": "Sally Smith",
                    "legal_address": {
                        "first_name": "Sally",
                        "last_name": "Smith",
                        "street_address": ["Main Street"],
                        "country": "US",
                        "state_or_territory": "US-CO",
                        "city": "Boulder",
                        "postal_code": "02183",
                    },
                },
                {
                    "flow": auth_state["flow"],
                    "state": SocialAuthState.STATE_REGISTER_EXTRA_DETAILS,
                },
            )
            assert ExportsInquiryLog.objects.filter(
                user__email=self.email).exists()
            assert (ExportsInquiryLog.objects.get(
                user__email=self.email).computed_result == RESULT_SUCCESS)
            assert len(mail.outbox) == 0

            self.user = User.objects.get(email=self.email)
            return result

    @rule(auth_state=consumes(ConfirmationRedeemedAuthStates))
    def register_details_export_reject(self, auth_state):
        """Complete the register confirmation details page with exports enabled"""
        with export_check_response("700_reject"):
            assert_api_call(
                self.client,
                "psa-register-details",
                {
                    "flow": auth_state["flow"],
                    "partial_token": auth_state["partial_token"],
                    "password": self.password,
                    "name": "Sally Smith",
                    "legal_address": {
                        "first_name": "Sally",
                        "last_name": "Smith",
                        "street_address": ["Main Street"],
                        "country": "US",
                        "state_or_territory": "US-CO",
                        "city": "Boulder",
                        "postal_code": "02183",
                    },
                },
                {
                    "flow": auth_state["flow"],
                    "partial_token": None,
                    "errors": ["Error code: CS_700"],
                    "state": SocialAuthState.STATE_USER_BLOCKED,
                },
            )
            assert ExportsInquiryLog.objects.filter(
                user__email=self.email).exists()
            assert (ExportsInquiryLog.objects.get(
                user__email=self.email).computed_result == RESULT_DENIED)
            assert len(mail.outbox) == 1

    @rule(auth_state=consumes(ConfirmationRedeemedAuthStates))
    def register_details_export_temporary_error(self, auth_state):
        """Complete the register confirmation details page with exports raising a temporary error"""
        with override_settings(**get_cybersource_test_settings()), patch(
                "authentication.pipeline.compliance.api.verify_user_with_exports",
                side_effect=Exception(
                    "register_details_export_temporary_error"),
        ):
            assert_api_call(
                self.client,
                "psa-register-details",
                {
                    "flow": auth_state["flow"],
                    "partial_token": auth_state["partial_token"],
                    "password": self.password,
                    "name": "Sally Smith",
                    "legal_address": {
                        "first_name": "Sally",
                        "last_name": "Smith",
                        "street_address": ["Main Street"],
                        "country": "US",
                        "state_or_territory": "US-CO",
                        "city": "Boulder",
                        "postal_code": "02183",
                    },
                },
                {
                    "flow":
                    auth_state["flow"],
                    "partial_token":
                    None,
                    "errors": [
                        "Unable to register at this time, please try again later"
                    ],
                    "state":
                    SocialAuthState.STATE_ERROR_TEMPORARY,
                },
            )
            assert not ExportsInquiryLog.objects.filter(
                user__email=self.email).exists()
            assert len(mail.outbox) == 0

    @rule(auth_state=consumes(RegisterExtraDetailsAuthStates))
    def register_user_extra_details(self, auth_state):
        """Complete the user's extra details"""
        assert_api_call(
            Client(),
            "psa-register-extra",
            {
                "flow": auth_state["flow"],
                "partial_token": auth_state["partial_token"],
                "gender": "f",
                "birth_year": "2000",
                "company": "MIT",
                "job_title": "QA Manager",
            },
            {
                "flow": auth_state["flow"],
                "state": SocialAuthState.STATE_SUCCESS,
                "partial_token": None,
            },
            expect_authenticated=True,
        )
Example #14
    class TrustchainValidate(RuleBasedStateMachine):
        NonRevokedAdminUsers = Bundle("admin users")
        NonRevokedOtherUsers = Bundle("other users")
        RevokedUsers = Bundle("revoked users")

        def next_user_id(self):
            nonlocal name_count
            name_count += 1
            return UserID(f"user{name_count}")

        def next_device_id(self, user_id=None):
            nonlocal name_count
            user_id = user_id or self.next_user_id()
            name_count += 1
            return user_id.to_device_id(DeviceName(f"dev{name_count}"))

        def new_user_and_device(self, is_admin, certifier_id, certifier_key):
            device_id = self.next_device_id()

            local_device = local_device_factory(device_id, org=coolorg)
            self.local_devices[device_id] = local_device

            user = UserCertificateContent(
                author=certifier_id,
                timestamp=pendulum_now(),
                user_id=local_device.user_id,
                human_handle=local_device.human_handle,
                public_key=local_device.public_key,
                profile=UserProfile.ADMIN
                if is_admin else UserProfile.STANDARD,
            )
            self.users_content[device_id.user_id] = user
            self.users_certifs[device_id.user_id] = user.dump_and_sign(
                certifier_key)

            device = DeviceCertificateContent(
                author=certifier_id,
                timestamp=pendulum_now(),
                device_id=local_device.device_id,
                device_label=local_device.device_label,
                verify_key=local_device.verify_key,
            )
            self.devices_content[local_device.device_id] = device
            self.devices_certifs[
                local_device.device_id] = device.dump_and_sign(certifier_key)

            return device_id

        @initialize(target=NonRevokedAdminUsers)
        def init(self):
            caplog.clear()
            self.users_certifs = {}
            self.users_content = {}
            self.revoked_users_certifs = {}
            self.revoked_users_content = {}
            self.devices_certifs = {}
            self.devices_content = {}
            self.local_devices = {}
            device_id = self.new_user_and_device(
                is_admin=True,
                certifier_id=None,
                certifier_key=coolorg.root_signing_key)
            note(f"new device: {device_id}")
            return device_id.user_id

        def get_device(self, user_id, device_rand):
            user_devices = [
                device for device_id, device in self.local_devices.items()
                if device_id.user_id == user_id
            ]
            return user_devices[device_rand % len(user_devices)]

        @rule(
            target=NonRevokedAdminUsers,
            author_user=NonRevokedAdminUsers,
            author_device_rand=st.integers(min_value=0),
        )
        def new_admin_user(self, author_user, author_device_rand):
            author = self.get_device(author_user, author_device_rand)
            device_id = self.new_user_and_device(
                is_admin=True,
                certifier_id=author.device_id,
                certifier_key=author.signing_key)
            note(f"new device: {device_id} (author: {author.device_id})")
            return device_id.user_id

        @rule(
            target=NonRevokedOtherUsers,
            author_user=NonRevokedAdminUsers,
            author_device_rand=st.integers(min_value=0),
        )
        def new_non_admin_user(self, author_user, author_device_rand):
            author = self.get_device(author_user, author_device_rand)
            device_id = self.new_user_and_device(
                is_admin=False,
                certifier_id=author.device_id,
                certifier_key=author.signing_key)
            note(f"new device: {device_id} (author: {author.device_id})")
            return device_id.user_id

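        # Revoking requires another admin device to author the revocation;
        # consuming the user keeps revoked users out of the non-revoked bundles.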
        @precondition(lambda self: len(
            [d for d in self.local_devices.values() if d.is_admin]) > 1)
        @rule(
            target=RevokedUsers,
            user=st.one_of(consumes(NonRevokedAdminUsers),
                           consumes(NonRevokedOtherUsers)),
            author_rand=st.integers(min_value=0),
        )
        def revoke_user(self, user, author_rand):
            possible_authors = [
                device for device_id, device in self.local_devices.items()
                if device_id.user_id != user
                and device.profile == UserProfile.ADMIN
            ]
            author = possible_authors[author_rand % len(possible_authors)]
            note(f"revoke user: {user} (author: {author.device_id})")
            revoked_user = RevokedUserCertificateContent(
                author=author.device_id,
                timestamp=pendulum_now(),
                user_id=user)
            self.revoked_users_content[user] = revoked_user
            self.revoked_users_certifs[user] = revoked_user.dump_and_sign(
                author.signing_key)
            return user

        @rule(
            user=st.one_of(NonRevokedAdminUsers, NonRevokedOtherUsers),
            author_user=NonRevokedAdminUsers,
            author_device_rand=st.integers(min_value=0),
        )
        def new_device(self, user, author_user, author_device_rand):
            author = self.get_device(author_user, author_device_rand)
            device_id = self.next_device_id(user)
            note(f"new device: {device_id} (author: {author.device_id})")
            local_device = local_device_factory(device_id, org=coolorg)
            device = DeviceCertificateContent(
                author=author.device_id,
                timestamp=pendulum_now(),
                device_id=local_device.device_id,
                device_label=local_device.device_label,
                verify_key=local_device.verify_key,
            )
            self.devices_content[local_device.device_id] = device
            self.devices_certifs[
                local_device.device_id] = device.dump_and_sign(
                    author.signing_key)

        @rule(user=st.one_of(NonRevokedAdminUsers, NonRevokedOtherUsers))
        def load_trustchain(self, user):
            ctx = TrustchainContext(coolorg.root_verify_key, 1)

            user_certif = next(
                certif for user_id, certif in self.users_certifs.items()
                if user_id == user)
            revoked_user_certif = next(
                (certif
                 for user_id, certif in self.revoked_users_certifs.items()
                 if user_id == user),
                None,
            )
            devices_certifs = [
                certif for device_id, certif in self.devices_certifs.items()
                if device_id.user_id == user
            ]
            user_content, revoked_user_content, devices_contents = ctx.load_user_and_devices(
                trustchain={
                    "users":
                    [certif for certif in self.users_certifs.values()],
                    "revoked_users":
                    [certif for certif in self.revoked_users_certifs.values()],
                    "devices":
                    [certif for certif in self.devices_certifs.values()],
                },
                user_certif=user_certif,
                revoked_user_certif=revoked_user_certif,
                devices_certifs=devices_certifs,
                expected_user_id=user,
            )

            expected_user_content = next(
                content for user_id, content in self.users_content.items()
                if user_id == user)
            expected_revoked_user_content = next(
                (content
                 for user_id, content in self.revoked_users_content.items()
                 if user_id == user),
                None,
            )
            expected_devices_contents = [
                content for device_id, content in self.devices_content.items()
                if device_id.user_id == user
            ]
            assert user_content == expected_user_content
            assert revoked_user_content == expected_revoked_user_content
            assert sorted(devices_contents,
                          key=lambda device: device.device_id) == sorted(
                              expected_devices_contents,
                              key=lambda device: device.device_id)
Example #15
def test_deprecated_target_consumes_bundle():
    # It would be nicer to raise this error at runtime, but the internals make
    # this sadly impractical.  Most InvalidDefinition errors happen at, well,
    # definition-time already anyway, so it's not *worse* than the status quo.
    with validate_deprecation():
        rule(target=consumes(Bundle("b")))
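# Using consumes() on a rule's target bundle is deprecated; the call above is
# expected to trigger the warning checked by validate_deprecation().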
Example #16
class StatefulTestFileGenerator(RuleBasedStateMachine):
    model_object = Bundle("model_object")

    name = Bundle("name")
    constants = Bundle("constants")
    fields = Bundle("fields")
    meta = Bundle("meta")

    @initialize()
    def remove_generated_file(self):
        if os.path.isfile(FILE):
            os.remove(FILE)

    @rule(target=name, name=fake_class_name())
    def add_name(self, name):
        assume(not model_exists(name))
        return name

    @rule(target=constants, constants=fake_constants())
    def add_constants(self, constants):
        return constants

    @rule(target=fields, fields=fake_fields_data())
    def add_fields(self, fields):
        return fields

    @rule(target=meta, meta=default_meta())
    def add_meta(self, meta):
        return meta

    @rule(
        target=model_object,
        name=consumes(name),
        constants=constants,
        fields=fields,
        meta=meta,
    )
    def add_model_object(self, name, constants, fields, meta):
        # Remove duplicate fields.
        for field in fields:
            constants.pop(field, None)

        try:
            django_model = get_django_model(name=name,
                                            constants=constants,
                                            fields=fields,
                                            meta=meta)
            model_object = get_model_object(django_model)
        except Exception as e:
            pytest.fail(e)
        else:
            return model_object

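    # Both parameters consume from the model_object bundle, so each generated
    # model object is used in at most one original/tester comparison.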
    @rule(original=consumes(model_object), tester=consumes(model_object))
    def assert_file_generator(self, original, tester):
        initial_file = []
        if os.path.isfile(FILE):
            event("assert_file_generator: File already exists.")
            with open(FILE, "r") as f:
                initial_file = f.read().splitlines()
        else:
            event("assert_file_generator: File doesn't exists.")

        file_generator_instance = FileGenerator(original, tester)

        # Test File exists.
        assert os.path.isfile(FILE)
        with open(FILE, "r") as f:
            modified_file = f.read().splitlines()

        # Test File has Header.
        assert all(modified_file[index] == line
                   for index, line in enumerate(FILE_HEADER.splitlines()))

        # Test initial data isn't modified.
        for line_number, line in enumerate(initial_file):
            assert line == modified_file[line_number]

        # Test names are correct and attributes exist.
        appended_data = modified_file[len(initial_file):]
        pattern = r"assert (?P<original>\S+) == (?P<tester>\S+), assert_msg(.+)"
        for line in appended_data:
            assert_line = re.search(pattern, line)
            if assert_line:
                assert_line = {
                    model: {
                        "name": breadcrumb.split(".")[0],
                        "attr": ".".join(breadcrumb.split(".")[1:]),
                    }
                    for model, breadcrumb in assert_line.groupdict().items()
                }

                # Test names are correct.
                assert assert_line["original"]["name"] == original._meta.name
                assert assert_line["tester"]["name"] == tester._meta.name

                # Test attributes compared are the same.
                assert assert_line["original"]["attr"] == assert_line[
                    "tester"]["attr"]

                # Test attributes exist.
                assert hasattrs(original, assert_line["original"]["attr"])
                assert hasattrs(tester, assert_line["tester"]["attr"])
        # Try to retrieve the generated functions.
        try:
            generated_functions = file_generator_instance.get_functions()
        except Exception as e:
            pytest.fail(str(e))

        # Test Module Import.
        try:
            module = sys.modules[MODULE]

            # Test Module has the Generated Class.
            assert hasattr(module, tester._meta.name)

            # Test Generated Class has the generated functions.
            generated_class = getattr(module, tester._meta.name)
            for generated_function in generated_functions.keys():
                assert hasattr(generated_class, generated_function)
        except KeyError as e:
            pytest.fail(str(e))
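
# Hedged sketch: a RuleBasedStateMachine is normally executed through the
# unittest.TestCase that Hypothesis generates for it; the variable name below
# is arbitrary and not taken from the original project.
TestStatefulTestFileGenerator = StatefulTestFileGenerator.TestCase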
Example #17
class ListSM(RuleBasedStateMachine):
    def __init__(self):
        self._var = ffi.new('struct list*')
        self._model = deque()
        self._model_contents = deque()
        super().__init__()

    first = alias_property()
    last = alias_property()

    def teardown(self):
        lib.list_clear(self._var)

    class Iterator(typing.Iterator[None]):
        cur = alias_property()
        prev = alias_property()

    def _make_iter(self, reverse: bool = False) -> "ListSM.Iterator":
        var = ffi.new('struct list_it*')
        var.cur = self.last if reverse else self.first
        def it():
            while var.cur != ffi.NULL:
                yield  # yield None because we're mutable
                lib.list_step(var)
        it = it()
        class _It(ListSM.Iterator):
            _var = var
            def __next__(self) -> None:
                return next(it)
        return _It()

    def __iter__(self) -> "ListSM.Iterator":
        return self._make_iter()

    def __reversed__(self) -> "ListSM.Iterator":
        return self._make_iter(reverse=True)

    nodes = Bundle('Nodes')

    @rule(new_value=elements(), target=nodes)
    def insert_front(self, new_value):
        self._model_contents.appendleft(new_value)
        lib.list_insert_front(self._var, new_value)
        new_node = self.first
        assert new_node.value == new_value
        self._model.appendleft(new_node)
        return new_node

    @rule(new_value=elements(), target=nodes)
    def insert_back(self, new_value):
        self._model_contents.append(new_value)
        lib.list_insert_back(self._var, new_value)
        new_node = self.last
        assert new_node.value == new_value
        self._model.append(new_node)
        return new_node

    @rule(nodes=strategies.frozensets(consumes(nodes)), reverse=strategies.booleans())
    def remove_thru_iter(self, nodes, reverse):
        it = reversed(self) if reverse else iter(self)
        for _ in it:
            if it.cur in nodes:
                lib.list_remove(self._var, it._var)
        for n in nodes:
            i = self._model.index(n)
            del self._model_contents[i]
            del self._model[i]

    @invariant()
    def nodes_as_model(self):
        it = iter(self)
        nodes = [it.cur for _ in it]
        assert nodes == list(self._model)

    @invariant()
    def contents_as_model(self):
        it = iter(self)
        contents = [it.cur.value for _ in it]
        assert contents == list(self._model_contents)
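
# Hedged sketch: a machine such as ListSM can also be driven directly from a
# pytest test via run_state_machine_as_test; the settings values below are
# illustrative assumptions, not taken from the original project.
def test_list_state_machine():
    from hypothesis import settings
    from hypothesis.stateful import run_state_machine_as_test

    run_state_machine_as_test(ListSM, settings=settings(max_examples=50))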
def test_consumes_typecheck():
    with pytest.raises(TypeError):
        consumes(integers())
class InventoryStateMachine(RuleBasedStateMachine):
    def __init__(self):
        super(InventoryStateMachine, self).__init__()
        with clientContext() as client:
            self.client = client
            self.model_skus = {}
            self.model_bins = {}
            self.model_batches = {}
            self.model_users = {}
            self.logged_in_as = None

    a_bin_id = Bundle("bin_id")
    a_sku_id = Bundle("sku_id")
    a_batch_id = Bundle("batch_id")
    a_user_id = Bundle("user_id")

    @rule(target=a_user_id, user=dst.users_())
    def new_user(self, user):
        resp = self.client.post("/api/users", json=user)
        if user["id"] in self.model_users.keys():
            assert resp.status_code == 409
            assert resp.is_json
            assert resp.json['type'] == "duplicate-resource"
            return multiple()
        else:
            assert resp.status_code == 201
            self.model_users[user["id"]] = user
            return user["id"]

    @rule(user_id=consumes(a_user_id))
    def delete_existing_user(self, user_id):
        resp = self.client.delete(f"/api/user/{user_id}")
        del self.model_users[user_id]
        assert resp.status_code == 200

    @rule(user_id=a_user_id)
    def get_existing_user(self, user_id):
        resp = self.client.get(f"/api/user/{user_id}")
        assert resp.status_code == 200
        assert resp.is_json
        found_user = resp.json["state"]
        model_user = self.model_users[user_id]
        assert model_user["id"] == found_user["id"]
        assert model_user["name"] == found_user["name"]

    @rule(user_id=dst.ids)
    def get_missing_user(self, user_id):
        assume(user_id not in self.model_users.keys())
        resp = self.client.get(f"/api/user/{user_id}")
        assert resp.status_code == 404
        assert resp.is_json
        assert resp.json['type'] == 'missing-resource'

    @rule(user_id=dst.ids)
    def delete_missing_user(self, user_id):
        assume(user_id not in self.model_users.keys())
        resp = self.client.delete(f"/api/user/{user_id}")
        assert resp.status_code == 404
        assert resp.is_json
        assert resp.json['type'] == "missing-resource"

    @rule(user_id=a_user_id, data=st.data())
    def create_existing_user(self, user_id, data):
        user = data.draw(dst.users_(id=user_id))
        resp = self.client.post("/api/users", json=user)
        assert resp.status_code == 409
        assert resp.is_json
        assert resp.json["type"] == "duplicate-resource"

    user_patch = st.builds(
        lambda user, use_keys:
        {k: v
         for k, v in user.items() if k in use_keys}, dst.users_(),
        st.sets(st.sampled_from([
            "name",
            "password",
        ])))

    @rule(user_id=a_user_id, user_patch=user_patch)
    def update_existing_user(self, user_id, user_patch):
        rp = self.client.patch(f'/api/user/{user_id}', json=user_patch)
        assert rp.status_code == 200
        assert rp.cache_control.no_cache
        for key in user_patch.keys():
            self.model_users[user_id][key] = user_patch[key]

    @rule(user_id=a_user_id)
    def login_as(self, user_id):
        rp = self.client.post("/api/login",
                              json={
                                  "id": user_id,
                                  "password":
                                  self.model_users[user_id]["password"]
                              })
        assert rp.cache_control.no_cache
        self.logged_in_as = user_id

    @rule(user_id=a_user_id, password=st.text())
    def login_bad_password(self, user_id, password):
        assume(password != self.model_users[user_id]["password"])
        rp = self.client.post("/api/login",
                              json={
                                  "id": user_id,
                                  "password": password
                              })
        assert rp.status_code == 401
        assert rp.cache_control.no_cache
        assert rp.is_json

    @rule(user=dst.users_())
    def login_bad_username(self, user):
        assume(user['id'] not in self.model_users)
        rp = self.client.post("/api/login",
                              json={
                                  "id": user["id"],
                                  "password": user["password"]
                              })
        assert rp.status_code == 401
        assert rp.cache_control.no_cache
        assert rp.is_json

    @rule()
    def logout(self):
        rp = self.client.post("/api/logout")
        assert rp.status_code == 200
        assert rp.cache_control.no_cache
        assert rp.is_json
        self.logged_in_as = None

    @rule()
    def whoami(self):
        rp = self.client.get("/api/whoami")
        assert rp.status_code == 200
        assert rp.is_json
        if self.logged_in_as:
            assert rp.json["id"] == self.logged_in_as
        else:
            assert rp.json["id"] is None

    @rule(target=a_bin_id, bin=dst.bins_())
    def new_bin(self, bin):
        resp = self.client.post('/api/bins',
                                json=bin.to_dict(mask_default=True))
        if bin.id in self.model_bins.keys():
            assert resp.status_code == 409
            assert resp.is_json
            assert resp.json['type'] == 'duplicate-resource'
            return multiple()
        else:
            assert resp.status_code == 201
            self.model_bins[bin.id] = bin
            return bin.id

    @rule(bin_id=a_bin_id)
    def get_existing_bin(self, bin_id):
        assert bin_id in self.model_bins.keys()

        rp = self.client.get(f'/api/bin/{bin_id}')
        assert rp.status_code == 200
        assert rp.is_json
        assert self.model_bins[bin_id].props == rp.json['state'].get('props')
        found_bin = Bin.from_json(rp.json['state'])
        assert found_bin == self.model_bins[bin_id]

    @rule(bin_id=dst.label_("BIN"))
    def get_missing_bin(self, bin_id):
        assume(bin_id not in self.model_bins.keys())
        rp = self.client.get(f'/api/bin/{bin_id}')
        assert rp.status_code == 404
        assert rp.json['type'] == 'missing-resource'

    @rule(bin_id=a_bin_id, newProps=dst.propertyDicts)
    def update_bin(self, bin_id, newProps):
        # assume(self.model_bins[bin_id].props != newProps)
        rp = self.client.patch(f'/api/bin/{bin_id}',
                               json={
                                   "id": bin_id,
                                   "props": newProps
                               })
        self.model_bins[bin_id].props = newProps
        assert rp.status_code == 200
        assert rp.cache_control.no_cache

    @rule(bin_id=a_bin_id, newProps=dst.json)
    def update_missing_bin(self, bin_id, newProps):
        assume(bin_id not in self.model_bins.keys())
        rp = self.client.put(f'/api/bin/{bin_id}/props', json=newProps)
        assert rp.status_code == 404
        assert rp.json['type'] == 'missing-resource'

    @rule(bin_id=consumes(a_bin_id))
    def delete_empty_bin(self, bin_id):
        assume(self.model_bins[bin_id].contents == {})
        rp = self.client.delete(f'/api/bin/{bin_id}')
        del self.model_bins[bin_id]
        assert rp.status_code == 200
        assert rp.cache_control.no_cache

    @rule(bin_id=a_bin_id)
    def delete_nonempty_bin_noforce(self, bin_id):
        assume(self.model_bins[bin_id].contents != {})
        rp = self.client.delete(f'/api/bin/{bin_id}')
        assert rp.status_code == 405
        assert rp.is_json
        assert rp.json['type'] == 'dangerous-operation'

    @rule(bin_id=consumes(a_bin_id))
    def delete_nonempty_bin_force(self, bin_id):
        assume(self.model_bins[bin_id].contents != {})
        rp = self.client.delete(f'/api/bin/{bin_id}',
                                query_string={"force": "true"})
        del self.model_bins[bin_id]
        assert rp.status_code == 200
        assert rp.cache_control.no_cache

    @rule(bin_id=dst.label_("BIN"))
    def delete_missing_bin(self, bin_id):
        assume(bin_id not in self.model_bins.keys())
        rp = self.client.delete(f'/api/bin/{bin_id}')
        assert rp.status_code == 404
        assert rp.cache_control.no_cache
        assert rp.is_json
        assert rp.json['type'] == 'missing-resource'

    @rule(target=a_sku_id, sku=dst.skus_())
    def new_sku(self, sku):
        resp = self.client.post('/api/skus',
                                json=sku.to_dict(mask_default=True))
        if sku.id in self.model_skus.keys():
            assert resp.status_code == 409
            assert resp.is_json
            assert resp.json['type'] == 'duplicate-resource'
            return multiple()
        else:
            assert resp.status_code == 201
            self.model_skus[sku.id] = sku
            return sku.id

    @rule(sku=dst.skus_(),
          bad_code=st.sampled_from(
              ["", " ", "\t", "     ", " 123", "1 2 3", "123 abc"]))
    def new_sku_bad_format_owned_codes(self, sku, bad_code):
        assume(sku.id not in self.model_skus.keys())
        temp_sku = Sku.from_json(sku.to_json())
        temp_sku.owned_codes.append(bad_code)
        resp = self.client.post('/api/skus', json=temp_sku.to_dict())
        assert resp.status_code == 400
        assert resp.is_json
        assert resp.json['type'] == 'validation-error'

        temp_sku = Sku.from_json(sku.to_json())
        temp_sku.associated_codes.append(bad_code)
        resp = self.client.post('/api/skus', json=temp_sku.to_dict())
        assert resp.status_code == 400
        assert resp.is_json
        assert resp.json['type'] == 'validation-error'

    @rule(sku_id=a_sku_id)
    def get_existing_sku(self, sku_id):
        rp = self.client.get(f"/api/sku/{sku_id}")
        assert rp.status_code == 200
        assert rp.is_json
        found_sku = Sku(**rp.json['state'])
        assert found_sku == self.model_skus[sku_id]

    @rule(sku_id=dst.label_("SKU"))
    def get_missing_sku(self, sku_id):
        assume(sku_id not in self.model_skus.keys())
        rp = self.client.get(f"/api/sku/{sku_id}")
        assert rp.status_code == 404
        assert rp.is_json
        assert rp.json['type'] == 'missing-resource'

    sku_patch = st.builds(
        lambda sku, use_keys:
        {k: v
         for k, v in sku.__dict__.items() if k in use_keys}, dst.skus_(),
        st.sets(
            st.sampled_from(
                ["owned_codes", "associated_codes", "name", "props"])))

    @rule(sku_id=a_sku_id, patch=sku_patch)
    def update_sku(self, sku_id, patch):
        rp = self.client.patch(f'/api/sku/{sku_id}', json=patch)
        assert rp.status_code == 200
        assert rp.cache_control.no_cache
        for key in patch.keys():
            setattr(self.model_skus[sku_id], key, patch[key])

    @rule(sku_id=consumes(a_sku_id))
    def delete_unused_sku(self, sku_id):
        assume(not any([
            sku_id in bin.contents.keys() for bin in self.model_bins.values()
        ]))
        rp = self.client.delete(f"/api/sku/{sku_id}")
        assert rp.status_code == 204
        assert rp.cache_control.no_cache
        del self.model_skus[sku_id]

    @invariant()
    def positive_quantities(self):
        for bin_id, bin in self.model_bins.items():
            for item_id, quantity in bin.contents.items():
                assert quantity >= 1

    @rule(sku_id=a_sku_id)
    def sku_locations(self, sku_id):
        rp = self.client.get(f"/api/sku/{sku_id}/bins")
        assert rp.status_code == 200
        assert rp.is_json
        locations = rp.json['state']
        for bin_id, contents in locations.items():
            for item_id, quantity in contents.items():
                assert self.model_bins[bin_id].contents[item_id] == quantity
        model_locations = {}
        for bin_id, bin in self.model_bins.items():
            if sku_id in bin.contents.keys():
                model_locations[bin_id] = {
                    item_id: quantity
                    for item_id, quantity in bin.contents.items()
                    if item_id == sku_id
                }
        assert model_locations == locations

    @rule(sku_id=dst.label_("SKU"))
    def missing_sku_locations(self, sku_id):
        assume(sku_id not in self.model_skus.keys())
        rp = self.client.get(f"/api/sku/{sku_id}/bins")
        assert rp.status_code == 404
        assert rp.is_json
        assert rp.json['type'] == "missing-resource"

    @rule(sku_id=a_sku_id)
    def attempt_delete_used_sku(self, sku_id):
        assume(
            any([
                sku_id in bin.contents.keys()
                for bin in self.model_bins.values()
            ]))
        rp = self.client.delete(f"/api/sku/{sku_id}")
        assert rp.status_code == 403
        assert rp.is_json
        assert rp.json['type'] == "resource-in-use"

    @rule(sku_id=dst.label_("SKU"))
    def delete_missing_sku(self, sku_id):
        assume(sku_id not in self.model_skus.keys())
        rp = self.client.delete(f"/api/sku/{sku_id}")
        assert rp.status_code == 404
        assert rp.is_json
        assert rp.json['type'] == "missing-resource"

    @rule(target=a_batch_id, sku_id=a_sku_id, data=st.data())
    def new_batch_existing_sku(self, sku_id, data):
        # assume(self.model_skus != {})  # TODO: check if this is necessary
        batch = data.draw(dst.batches_(sku_id=sku_id))

        rp = self.client.post('/api/batches',
                              json=batch.to_dict(mask_default=True))

        if batch.id in self.model_batches.keys():
            assert rp.status_code == 409
            assert rp.json['type'] == 'duplicate-resource'
            assert rp.is_json
            return multiple()
        else:
            assert rp.status_code == 201
            self.model_batches[batch.id] = batch
            return batch.id

    @rule(data=dst.data(),
          sku_id=a_sku_id,
          bad_code=st.sampled_from(
              ["", " ", "\t", "     ", " 123", "1 2 3", "123 abc"]))
    def new_batch_bad_format_owned_codes(self, data, sku_id, bad_code):
        batch = data.draw(dst.batches_(sku_id=sku_id))
        assume(batch.id not in self.model_batches.keys())

        temp_batch = Batch.from_json(batch.to_json())
        temp_batch.owned_codes.append(bad_code)
        resp = self.client.post('/api/batches', json=temp_batch.to_dict())
        assert resp.status_code == 400
        assert resp.is_json
        assert resp.json['type'] == 'validation-error'

        temp_batch = Batch.from_json(batch.to_json())
        temp_batch.associated_codes.append(bad_code)
        resp = self.client.post('/api/batches', json=temp_batch.to_dict())
        assert resp.status_code == 400
        assert resp.is_json
        assert resp.json['type'] == 'validation-error'

    @rule(batch=dst.batches_())
    def new_batch_new_sku(self, batch):
        assume(batch.sku_id)
        assume(batch.sku_id not in self.model_skus.keys())
        rp = self.client.post("/api/batches", json=batch.to_json())

        assert rp.status_code == 409
        assert rp.is_json
        assert rp.json['type'] == "missing-resource"

    @rule(target=a_batch_id, batch=dst.batches_(sku_id=None))
    def new_anonymous_batch(self, batch):
        assert not batch.sku_id
        rp = self.client.post("/api/batches",
                              json=batch.to_dict(mask_default=True))

        if batch.id in self.model_batches.keys():
            assert rp.status_code == 409
            assert rp.json['type'] == 'duplicate-resource'
            assert rp.is_json
            return multiple()
        else:
            assert rp.json.get('type') is None
            assert rp.status_code == 201
            self.model_batches[batch.id] = batch
            return batch.id

    @rule(batch_id=a_batch_id)
    def get_existing_batch(self, batch_id):
        rp = self.client.get(f"/api/batch/{batch_id}")
        assert rp.status_code == 200
        assert rp.is_json
        found_batch = Batch.from_json(rp.json['state'])
        assert found_batch == self.model_batches[batch_id]

    @rule(batch_id=dst.label_("BAT"))
    def get_missing_batch(self, batch_id):
        assume(batch_id not in self.model_batches.keys())
        rp = self.client.get(f"/api/batch/{batch_id}")
        assert rp.status_code == 404
        assert rp.is_json
        assert rp.json['type'] == "missing-resource"

    batch_patch = st.builds(
        lambda batch, use_keys:
        {k: v
         for k, v in batch.__dict__.items() if k in use_keys}, dst.skus_(),
        st.sets(st.sampled_from(["owned_codes", "associated_codes", "props"])))

    @rule(batch_id=a_batch_id, patch=batch_patch)
    def update_batch(self, batch_id, patch):
        patch['id'] = batch_id
        rp = self.client.patch(f'/api/batch/{batch_id}', json=patch)
        assert rp.status_code == 200
        assert rp.cache_control.no_cache
        for key in patch.keys():
            setattr(self.model_batches[batch_id], key, patch[key])

    @rule(batch_id=dst.label_("BAT"), patch=batch_patch)
    def update_nonexisting_batch(self, batch_id, patch):
        patch['id'] = batch_id

        assume(batch_id not in self.model_batches.keys())
        rp = self.client.patch(f'/api/batch/{batch_id}', json=patch)
        assert rp.status_code == 404
        assert rp.is_json
        assert rp.json['type'] == "missing-resource"

    @rule(batch_id=a_batch_id, sku_id=a_sku_id, patch=batch_patch)
    def attempt_update_nonanonymous_batch_sku_id(self, batch_id, sku_id,
                                                 patch):
        patch['id'] = batch_id

        assume(self.model_batches[batch_id].sku_id)
        assume(sku_id != self.model_batches[batch_id].sku_id)
        patch['sku_id'] = sku_id
        rp = self.client.patch(f'/api/batch/{batch_id}', json=patch)
        assert rp.status_code == 405
        assert rp.is_json
        assert rp.json['type'] == "dangerous-operation"

    @rule(batch_id=a_batch_id, sku_id=a_sku_id, patch=batch_patch)
    def update_anonymous_batch_existing_sku_id(self, batch_id, sku_id, patch):
        patch['id'] = batch_id

        assume(not self.model_batches[batch_id].sku_id)
        patch['sku_id'] = sku_id
        rp = self.client.patch(f"/api/batch/{batch_id}", json=patch)
        assert rp.status_code == 200
        assert rp.cache_control.no_cache
        for key in patch.keys():
            setattr(self.model_batches[batch_id], key, patch[key])

    @rule(batch_id=a_batch_id, sku_id=dst.label_("SKU"), patch=batch_patch)
    def attempt_update_anonymous_batch_missing_sku_id(self, batch_id, sku_id,
                                                      patch):
        patch['id'] = batch_id

        assume(sku_id not in self.model_skus.keys())
        patch['sku_id'] = sku_id
        rp = self.client.patch(f"/api/batch/{batch_id}", json=patch)
        assert rp.status_code == 400
        assert rp.is_json
        assert rp.json['type'] == "validation-error"
        assert {
            'name': 'sku_id',
            'reason': 'must be an existing sku id'
        } in rp.json['invalid-params']

    @rule(batch_id=consumes(a_batch_id))
    def delete_unused_batch(self, batch_id):
        assume(not any([
            batch_id in bin.contents.keys()
            for bin in self.model_bins.values()
        ]))
        rp = self.client.delete(f"/api/batch/{batch_id}")
        del self.model_batches[batch_id]
        assert rp.status_code == 200
        assert rp.cache_control.no_cache

    @rule(batch_id=a_batch_id)
    def batch_locations(self, batch_id):
        rp = self.client.get(f"/api/batch/{batch_id}/bins")
        assert rp.status_code == 200
        assert rp.is_json
        locations = rp.json['state']

        for bin_id, contents in locations.items():
            for item_id, quantity in contents.items():
                assert self.model_bins[bin_id].contents[item_id] == quantity

        model_locations = {}
        for bin_id, bin in self.model_bins.items():
            if batch_id in bin.contents.keys():
                model_locations[bin_id] = {
                    item_id: quantity
                    for item_id, quantity in bin.contents.items()
                    if item_id == batch_id
                }
        assert model_locations == locations

    @rule(batch_id=dst.label_("BAT"))
    def nonexisting_batch_locations(self, batch_id):
        assume(batch_id not in self.model_batches.keys())
        rp = self.client.get(f"/api/batch/{batch_id}/bins")
        assert rp.status_code == 404
        assert rp.is_json
        assert rp.json['type'] == "missing-resource"

    @rule(sku_id=a_sku_id)
    def sku_batches(self, sku_id):
        rp = self.client.get(f"/api/sku/{sku_id}/batches")
        assert rp.status_code == 200
        assert rp.is_json
        batch_ids = rp.json['state']

        model_batch_ids = [
            batch.id for batch in self.model_batches.values()
            if batch.sku_id == sku_id
        ]
        assert batch_ids == model_batch_ids

    @rule(sku_id=dst.label_("SKU"))
    def missing_sku_batches(self, sku_id):
        assume(sku_id not in self.model_skus.keys())
        rp = self.client.get(f"/api/sku/{sku_id}/batches")
        assert rp.status_code == 404
        assert rp.is_json
        assert rp.json['type'] == 'missing-resource'

    # Inventory operations

    @rule(bin_id=a_bin_id, sku_id=a_sku_id, quantity=st.integers(1, 100))
    def receive_sku(self, bin_id, sku_id, quantity):
        rp = self.client.post(f"/api/bin/{bin_id}/contents",
                              json={
                                  "id": sku_id,
                                  "quantity": quantity
                              })
        assert rp.status_code == 201
        self.model_bins[bin_id].contents[sku_id] \
            = self.model_bins[bin_id].contents.get(sku_id, 0) + quantity

    @rule(bin_id=dst.label_("BIN"),
          sku_id=dst.label_("SKU"),
          quantity=st.integers(1, 100))
    def receive_missing_sku_bin(self, bin_id, sku_id, quantity):
        rp = self.client.post(f"/api/bin/{bin_id}/contents",
                              json={
                                  "id": sku_id,
                                  "quantity": quantity
                              })
        if bin_id not in self.model_bins.keys():
            assert rp.status_code == 404
            assert rp.is_json
            assert rp.json['type'] == 'missing-resource'
        elif sku_id not in self.model_skus.keys():
            assert rp.status_code == 409
            assert rp.is_json
            assert rp.json['type'] == 'missing-resource'

    @rule(bin_id=a_bin_id, batch_id=a_batch_id, quantity=st.integers(1, 100))
    def receive_batch(self, bin_id, batch_id, quantity):
        rp = self.client.post(f"/api/bin/{bin_id}/contents",
                              json={
                                  "id": batch_id,
                                  "quantity": quantity
                              })
        assert rp.status_code == 201
        self.model_bins[bin_id].contents[batch_id] \
            = self.model_bins[bin_id].contents.get(batch_id, 0) + quantity

    @rule(bin_id=a_bin_id, batch_id=a_batch_id, quantity=st.integers(1, 100))
    def receive_missing_batch_bin(self, bin_id, batch_id, quantity):
        assume(bin_id not in self.model_bins.keys()
               or batch_id not in self.model_batches.keys())
        rp = self.client.post(f"/api/bin/{bin_id}/contents",
                              json={
                                  "id": batch_id,
                                  "quantity": quantity
                              })
        if bin_id not in self.model_bins.keys():
            assert rp.status_code == 404
            assert rp.is_json
            assert rp.json['type'] == 'missing-resource'
        elif batch_id not in self.model_batches.keys():
            assert rp.status_code == 409
            assert rp.is_json
            assert rp.json['type'] == 'missing-resource'

    @rule(source_binId=a_bin_id, destination_binId=a_bin_id, data=st.data())
    def move(self, source_binId, destination_binId, data):
        assume(source_binId != destination_binId)
        # assume(sku_id in self.model_bins[source_binId].contents.keys())
        assume(self.model_bins[source_binId].contents != {})
        sku_id = data.draw(
            st.sampled_from(list(
                self.model_bins[source_binId].contents.keys())))
        # assume(quantity >= self.model_bins[source_binId].contents[sku_id])
        quantity = data.draw(
            st.integers(1, self.model_bins[source_binId].contents[sku_id]))
        rp = self.client.put(f'/api/bin/{source_binId}/contents/move',
                             json={
                                 "id": sku_id,
                                 "quantity": quantity,
                                 "destination": destination_binId
                             })
        assert rp.status_code == 204
        assert rp.cache_control.no_cache

        self.model_bins[source_binId].contents[sku_id] -= quantity
        self.model_bins[destination_binId].contents[sku_id] = quantity \
            + self.model_bins[destination_binId].contents.get(sku_id, 0)
        if self.model_bins[source_binId].contents[sku_id] == 0:
            del self.model_bins[source_binId].contents[sku_id]

    @rule()
    def api_next(self):
        rp = self.client.get("/api/next/bin")
        assert rp.status_code == 200
        assert rp.is_json
        next_bin = rp.json['state']
        assert next_bin not in self.model_bins.keys()
        assert next_bin.startswith("BIN")
        assert len(next_bin) == 9

        rp = self.client.get("/api/next/sku")
        assert rp.status_code == 200
        assert rp.is_json
        next_sku = rp.json['state']
        assert next_sku not in self.model_skus.keys()
        assert next_sku.startswith("SKU")
        assert len(next_sku) == 9

        rp = self.client.get("/api/next/batch")
        assert rp.status_code == 200
        assert rp.is_json
        next_batch = rp.json['state']
        assert next_batch not in self.model_batches.keys()
        assert next_batch.startswith("BAT")
        assert len(next_batch) == 9

    def search_results_generator(self, query):
        def json_to_data_model(in_json_dict):
            if in_json_dict['id'].startswith("BIN"):
                return Bin.from_json(in_json_dict)
            if in_json_dict['id'].startswith("SKU"):
                return Sku.from_json(in_json_dict)
            if in_json_dict['id'].startswith("BAT"):
                return Batch.from_json(in_json_dict)

        starting_from = 0
        while True:
            rp = self.client.get("/api/search",
                                 query_string={
                                     "query": query,
                                     "startingFrom": starting_from,
                                 })
            assert rp.status_code == 200
            assert rp.is_json
            search_state = rp.json['state']
            for result_json in search_state['results']:
                yield json_to_data_model(result_json)
            if search_state['starting_from'] + search_state[
                    'limit'] > search_state['total_num_results']:
                break
            else:
                starting_from += search_state['limit']

    def search_query_matches(self, query, unit):
        STOP_WORDS = "a and the".split()
        terms = query.split()
        if len(terms) != 1:
            pass
        elif query == "" or query in STOP_WORDS:
            return False
        elif query == unit.id:
            return True
        elif hasattr(unit, "owned_codes") and query in unit.owned_codes:
            return True
        elif hasattr(unit,
                     "associated_codes") and query in unit.associated_codes:
            return True
        elif hasattr(unit,
                     "name") and query.casefold() in unit.name.casefold():
            return True
        return False

    # @rule(query=dst.search_query)
    # def search(self, query):
    #     results = list(self.search_results_generator(query))
    #     for unit in it.chain(self.model_bins.values(), self.model_skus.values(), self.model_batches.values()):
    #         if self.search_query_matches(query, unit):
    #             assert unit in results
    #         else:
    #             assert unit not in results

    @rule()
    def search_no_query(self):
        results = list(self.search_results_generator(""))
        assert results == []

    # Safety Invariants

    @invariant()
    def batches_skus_with_same_sku_never_share_bin(self):
        return  # TODO: remove this skipped check
        for bin in self.model_bins.values():
            sku_types = []
            for item_id in bin.contents.keys():
                if item_id.startswith("BAT"):
                    sku_type = self.model_batches[item_id].sku_id
                elif item_id.startswith("SKU"):
                    sku_type = item_id
                else:
                    assert False  # bin contents must be Batches or Skus
                if sku_type:
                    assert sku_type not in sku_types  # each sku type should only appear once
                    sku_types.append(sku_type)
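
# Hedged sketch: stateful API tests like InventoryStateMachine are usually wired
# up through the generated TestCase, optionally with stateful-specific settings;
# the values below are illustrative assumptions, not taken from the original
# project.
from hypothesis import settings

TestInventory = InventoryStateMachine.TestCase
TestInventory.settings = settings(max_examples=30,
                                  stateful_step_count=20,
                                  deadline=None)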
class MediatorMixin:
    def __init__(self):
        super().__init__()
        self.partner_to_balance_proof_data = dict()
        self.secrethash_to_secret = dict()
        self.waiting_for_unlock = dict()
        self.initial_number_of_channels = 2

    def _get_balance_proof_data(self, partner):
        if partner not in self.partner_to_balance_proof_data:
            partner_channel = self.address_to_channel[partner]
            self.partner_to_balance_proof_data[partner] = BalanceProofData(
                canonical_identifier=partner_channel.canonical_identifier)
        return self.partner_to_balance_proof_data[partner]

    def _update_balance_proof_data(self, partner, amount, expiration, secret):
        expected = self._get_balance_proof_data(partner)
        lock = HashTimeLockState(amount=amount,
                                 expiration=expiration,
                                 secrethash=sha256_secrethash(secret))
        expected.update(amount, lock)
        return expected

    init_mediators = Bundle("init_mediators")
    secret_requests = Bundle("secret_requests")
    unlocks = Bundle("unlocks")

    def _new_mediator_transfer(self, initiator_address, target_address,
                               payment_id, amount,
                               secret) -> LockedTransferSignedState:
        initiator_pkey = self.address_to_privkey[initiator_address]
        balance_proof_data = self._update_balance_proof_data(
            initiator_address, amount, self.block_number + 10, secret)
        self.secrethash_to_secret[sha256_secrethash(secret)] = secret

        return factories.create(
            factories.LockedTransferSignedStateProperties(
                **balance_proof_data.properties.__dict__,
                amount=amount,
                expiration=self.block_number + 10,
                payment_identifier=payment_id,
                secret=secret,
                initiator=initiator_address,
                target=target_address,
                token=self.token_id,
                sender=initiator_address,
                recipient=self.address,
                pkey=initiator_pkey,
                message_identifier=1,
            ))

    def _action_init_mediator(
            self, transfer: LockedTransferSignedState) -> ActionInitMediator:
        initiator_channel = self.address_to_channel[transfer.initiator]
        target_channel = self.address_to_channel[transfer.target]

        return ActionInitMediator(
            route_states=[factories.make_route_from_channel(target_channel)],
            from_hop=factories.make_hop_to_channel(initiator_channel),
            from_transfer=transfer,
            balance_proof=transfer.balance_proof,
            sender=transfer.balance_proof.sender,
        )

    @rule(
        target=init_mediators,
        initiator_address=partners,
        target_address=partners,
        payment_id=payment_id(),  # pylint: disable=no-value-for-parameter
        amount=integers(min_value=1, max_value=100),
        secret=secret(),  # pylint: disable=no-value-for-parameter
    )
    def valid_init_mediator(self, initiator_address, target_address,
                            payment_id, amount, secret):
        assume(initiator_address != target_address)

        transfer = self._new_mediator_transfer(initiator_address,
                                               target_address, payment_id,
                                               amount, secret)
        action = self._action_init_mediator(transfer)
        result = node.state_transition(self.chain_state, action)

        assert event_types_match(result.events, SendProcessed,
                                 SendLockedTransfer)

        return action

    @rule(target=secret_requests, previous_action=consumes(init_mediators))
    def valid_receive_secret_reveal(self, previous_action):
        secret = self.secrethash_to_secret[
            previous_action.from_transfer.lock.secrethash]
        sender = previous_action.from_transfer.target
        recipient = previous_action.from_transfer.initiator

        action = ReceiveSecretReveal(secret=secret, sender=sender)
        result = node.state_transition(self.chain_state, action)

        expiration = previous_action.from_transfer.lock.expiration
        in_time = self.block_number < expiration - DEFAULT_NUMBER_OF_BLOCK_CONFIRMATIONS
        still_waiting = self.block_number < expiration + DEFAULT_WAIT_BEFORE_LOCK_REMOVAL

        if in_time and self.channel_opened(sender) and self.channel_opened(
                recipient):
            assert event_types_match(result.events, SendSecretReveal,
                                     SendBalanceProof, EventUnlockSuccess)
            self.event("Unlock successful.")
            self.waiting_for_unlock[secret] = recipient
        elif still_waiting and self.channel_opened(recipient):
            assert event_types_match(result.events, SendSecretReveal)
            self.event("Unlock failed, secret revealed too late.")
        else:
            assert not result.events
            self.event(
                "ReceiveSecretRevealed after removal of lock - dropped.")
        return action

    @rule(previous_action=secret_requests)
    def replay_receive_secret_reveal(self, previous_action):
        result = node.state_transition(self.chain_state, previous_action)
        assert not result.events

    # pylint: disable=no-value-for-parameter
    @rule(previous_action=secret_requests, invalid_sender=address())
    # pylint: enable=no-value-for-parameter
    def replay_receive_secret_reveal_scrambled_sender(self, previous_action,
                                                      invalid_sender):
        action = ReceiveSecretReveal(previous_action.secret, invalid_sender)
        result = node.state_transition(self.chain_state, action)
        assert not result.events

    # pylint: disable=no-value-for-parameter
    @rule(previous_action=init_mediators, secret=secret())
    # pylint: enable=no-value-for-parameter
    def wrong_secret_receive_secret_reveal(self, previous_action, secret):
        sender = previous_action.from_transfer.target
        action = ReceiveSecretReveal(secret, sender)
        result = node.state_transition(self.chain_state, action)
        assert not result.events

    # pylint: disable=no-value-for-parameter
    @rule(target=secret_requests,
          previous_action=consumes(init_mediators),
          invalid_sender=address())
    # pylint: enable=no-value-for-parameter
    def wrong_address_receive_secret_reveal(self, previous_action,
                                            invalid_sender):
        secret = self.secrethash_to_secret[
            previous_action.from_transfer.lock.secrethash]
        invalid_action = ReceiveSecretReveal(secret, invalid_sender)
        result = node.state_transition(self.chain_state, invalid_action)
        assert not result.events

        valid_sender = previous_action.from_transfer.target
        valid_action = ReceiveSecretReveal(secret, valid_sender)
        return valid_action
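
# Hedged sketch: MediatorMixin is only a mixin; in the source project it is
# combined with a base machine that provides chain_state, block_number,
# address_to_channel, the partners bundle, etc.  ChainStateStateMachine is an
# assumed base-class name used purely for illustration.
class MediatorStateMachine(MediatorMixin, ChainStateStateMachine):
    """Run the mediator rules on top of the shared chain-state machinery."""

TestMediator = MediatorStateMachine.TestCase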
class StatefulPytestDjangoModelGenerator(RuleBasedStateMachine):
    name = Bundle("name")
    constants = Bundle("constants")
    fields = Bundle("fields")
    meta = Bundle("meta")

    @rule(target=name, name=fake_class_name())
    def add_name(self, name):
        assume(not model_exists(name))

        return name

    @rule(target=constants, constants=fake_constants())
    def add_constants(self, constants):
        return constants

    @rule(target=fields, fields=fake_fields_data())
    def add_fields(self, fields):
        return fields

    @rule(target=meta, meta=default_meta())
    def add_meta(self, meta):
        return meta

    @rule(name=consumes(name), constants=constants, fields=fields, meta=meta)
    def test_model_generator(self, name, constants, fields, meta):
        # Remove duplicate fields.
        for field in fields:
            constants.pop(field, None)

        try:
            django_model = get_django_model(name=name,
                                            constants=constants,
                                            fields=fields,
                                            meta=meta)
            model_object = get_model_object(django_model)
        except Exception as e:
            pytest.fail(str(e))

        # Test Constants
        assert hasattr(model_object._meta, "constants")
        for constant_name, value in constants.items():
            constant_attribute_object = AttributeObject(name=constant_name,
                                                        value=value,
                                                        parents=name)
            assert constant_name in model_object._meta.constants
            assert hasattr(model_object, constant_name)
            assert getattr(model_object,
                           constant_name) == constant_attribute_object

        # Test Fields
        assert hasattr(model_object._meta, "fields")
        for field_name, data in fields.items():
            # Prepare object to test.
            field_attrs = data["class"](**data["attrs"]).deconstruct()[3]
            field_attribute_object = AttributeObject(name=field_name,
                                                     cls=data["class"],
                                                     value=field_attrs,
                                                     parents=name)
            assert field_name in model_object._meta.fields
            assert hasattr(model_object, field_name)
            assert getattr(model_object, field_name) == field_attribute_object

        # Test Meta
        assert hasattr(model_object, "Meta")
        assert hasattr(model_object._meta, "meta")
        for option_name, value in meta.items():
            meta_attribute_object = AttributeObject(name=option_name,
                                                    value=value,
                                                    parents=[name, "Meta"])
            assert option_name in model_object._meta.meta
            assert hasattr(model_object.Meta, option_name)
            assert getattr(model_object.Meta,
                           option_name) == meta_attribute_object
Example #22
class StatefulTestPytestDjangoModel(RuleBasedStateMachine):
    # Basic Components
    ##################
    name = Bundle("name")

    @rule(target=name, name=fake_class_name())
    def add_name(self, name):
        assume(not model_exists(name))

        return name

    # Intermediate Components
    #########################
    model_data = Bundle("model_data")
    parent = Bundle("parent")

    @rule(
        target=model_data,
        name=consumes(name),
        constants=fake_constants(),
        fields=fake_fields_data(),
        meta=default_meta(),
    )
    def add_model(self, name, constants, fields, meta):
        # Remove duplicate fields.
        for field in fields:
            constants.pop(field, None)

        return Kwargs(name=name, constants=constants, fields=fields, meta=meta)

    @rule(
        target=parent,
        name=consumes(name),
        constants=fake_constants(),
        fields=fake_fields_data(min_size=1),
        meta=default_meta(),
        pk_name=fake_attr_name(),
    )
    def add_parent(self, name, constants, fields, meta, pk_name):
        try:
            pk = {
                pk_name: {
                    "class": CharField,
                    "attrs": {"primary_key": True, "max_length": 256},
                }
            }
            parent = get_django_model(
                name=name, constants=constants, fields={**pk, **fields}, meta=meta
            )
        except Exception as e:
            pytest.fail(str(e))
        else:
            parent.fields = [field.name for field in get_model_fields(parent)]

            return parent

    # PytestDjangoModel Components
    ######################
    data = Bundle("data")

    @rule(
        target=data,
        original=consumes(model_data),
        tester=consumes(model_data),
        parents=st.lists(elements=parent, max_size=3, unique=True).filter(
            lambda parents: not have_similarities(
                *[parent.fields for parent in parents]
            )
        ),
    )
    def add_test_model_data(self, original, tester, parents):
        # Remove fields duplicated in the parents.
        parents_fields = [field for parent in parents for field in parent.fields]
        for field in parents_fields:
            original.fields.pop(field, None)
            tester.fields.pop(field, None)

        # Prepare Original Objects.
        try:
            original_django_model = get_django_model(
                name=original.name,
                constants=original.constants,
                fields=original.fields,
                meta=original.meta,
                parents=tuple(parents),
            )
            original_model_object = get_model_object(original_django_model)
            original = Kwargs(
                data=original,
                django_model=original_django_model,
                model_object=original_model_object,
            )
        except Exception as e:
            pytest.fail(str(e))

        # Prepare Test Model dct.
        dct = {
            **tester.constants,
            **get_fields(tester.fields),
            "Meta": get_meta_class(**tester.meta),
        }

        name = tester.name
        dct["Meta"].model = original.django_model
        dct["Meta"].parents = parents

        return Kwargs(
            original=original, tester=tester, parents=parents, name=name, dct=dct
        )

    # Tests
    #######
    dirty_fields = Bundle("dirty_fields")

    @rule(
        target=dirty_fields,
        cls_name=consumes(name),
        fields=fake_fields_data(dirty=True, min_size=1),
    )
    def add_dirty_fields(self, cls_name, fields):
        django_model = get_django_model(cls_name, constants={}, fields=fields, meta={})
        assume(django_model.check())

        return get_fields(fields)

    @rule(data=consumes(data))
    def assert_no_error(self, data):
        original, tester, parents = data.original, data.tester, data.parents
        name, bases, dct = data.tester.name, (), data.dct

        name = tester.name
        dct["Meta"].model = original.django_model
        dct["Meta"].parents = parents

        try:
            test_model = PytestDjangoModel(name, bases, dct)
        except Exception as e:
            pytest.fail(str(e))

        assert not model_exists(name)

    @rule(data=consumes(data))
    def assert_no_meta(self, data):
        original, tester, parents = data.original, data.tester, data.parents
        name, bases, dct = data.tester.name, (), data.dct

        dct.pop("Meta")

        error_msg = f"{name} must have a 'Meta' inner class with 'model' attribute."
        with pytest.raises(ModelNotFoundError, match=error_msg):
            PytestDjangoModel(name, bases, dct)

        assert not model_exists(name)

    @rule(data=consumes(data), invalid_class_name=consumes(name))
    def assert_original_model_not_found(self, data, invalid_class_name):
        original, tester, parents = data.original, data.tester, data.parents
        name, bases, dct = data.tester.name, (), data.dct

        delattr(dct["Meta"], "model")

        error_msg = f"'Meta' inner class has not 'model' attribute."
        with pytest.raises(ModelNotFoundError, match=error_msg):
            PytestDjangoModel(name, bases, dct)

        assert not model_exists(name)

    @rule(
        data=consumes(data),
        invalid_class_name=consumes(name),
        invalid_type=st.one_of(*BASIC_TYPES).filter(lambda x: not isinstance(x, float)),
        isclass=st.booleans(),
    )
    def assert_original_is_invalid_model(
        self, data, invalid_class_name, invalid_type, isclass
    ):
        original, tester, parents = data.original, data.tester, data.parents
        name, bases, dct = data.tester.name, (), data.dct

        invalid_class = type(invalid_class_name, (), {}) if isclass else invalid_type
        dct["Meta"].model = invalid_class

        error_msg = get_invalid_model_msg(invalid_class)
        with pytest.raises(InvalidModelError, match=error_msg):
            PytestDjangoModel(name, bases, dct)

        assert not model_exists(name)

    @rule(
        data=consumes(data),
        invalid_class_name=consumes(name),
        isiterable=st.booleans(),
        invalid_type=st.one_of(*BASIC_TYPES).filter(lambda x: not isinstance(x, float)),
        isclass=st.booleans(),
    )
    def assert_parents_invalid_model(
        self, data, invalid_class_name, invalid_type, isclass, isiterable
    ):
        original, tester, parents = data.original, data.tester, data.parents
        name, bases, dct = data.tester.name, (), data.dct

        invalid_class = type(invalid_class_name, (), {}) if isclass else invalid_type
        error_msg = get_invalid_model_msg(invalid_class)

        if isiterable:
            event("assert_parents_invalid_model: parent is an iterable.")
            dct["Meta"].parents.append(invalid_class)
            error_msg = f"'parents' contains invalid model: {error_msg}"
        else:
            event("assert_parents_invalid_model: parent isn't an iterable.")
            dct["Meta"].parents = invalid_class
            error_msg = f"'parents': {error_msg}"

        with pytest.raises(InvalidModelError, match=error_msg):
            PytestDjangoModel(name, bases, dct)

        assert not model_exists(name)

    @rule(data=consumes(data), dirty_fields=dirty_fields)
    def assert_is_invalid_model(self, data, dirty_fields):
        original, tester, parents = data.original, data.tester, data.parents
        name, bases, dct = data.tester.name, (), data.dct

        # Assume no dirty field collides with the parents' fields.
        parents_fields = [field for parent in parents for field in parent.fields]
        assume(not have_similarities(dirty_fields, parents_fields))

        dct = {**dct, **dirty_fields}

        error_msg = (
            fr"^The {name} Model get the following errors during validation:(.|\s)+"
        )
        with pytest.raises(InvalidModelError, match=error_msg):
            PytestDjangoModel(name, bases, dct)

        if model_exists(name):
            delete_django_model(APP_LABEL, name)
        assert not model_exists(name)
Example #23
class HypothesisStateMachine(RuleBasedStateMachine):
    def __init__(self):
        super().__init__()
        self.chart = SseqChart("test")
        self.num_classes = 0
        self.num_edges = 0

    classes = Bundle("classes")
    structline_bdl = Bundle("structlines")
    differential_bdl = Bundle("differentials")
    extension_bdl = Bundle("extensions")
    structlines = structline_bdl.filter(lambda x: not x._deleted)
    differentials = differential_bdl.filter(lambda x: not x._deleted)
    extensions = extension_bdl.filter(lambda x: not x._deleted)

    edges = st.one_of(structlines, differentials, extensions)
    chart_objects = st.one_of(edges, classes)
    edge_and_range = st.one_of(
        st.tuples(structlines, slices()),
        st.tuples(st.one_of(differentials, extensions), st.none()),
    )
    obj_and_range = st.one_of(st.tuples(classes, slices()), edge_and_range)

    @rule(target=classes, k=st.tuples(integers, integers))
    def add_class(self, k):
        self.num_classes += 1
        return self.chart.add_class(*k)

    @rule(target=structline_bdl, c1=classes, c2=classes)
    def add_structline(self, c1, c2):
        self.num_edges += 1
        return self.chart.add_structline(c1, c2)

    @rule(target=extension_bdl, c1=classes, c2=classes)
    def add_extension(self, c1, c2):
        self.num_edges += 1
        return self.chart.add_extension(c1, c2)

    @rule(
        target=differential_bdl,
        page=integers,
        c1=classes,
        c2=classes,
        auto=st.booleans(),
    )
    def add_differential(self, page, c1, c2, auto):
        self.num_edges += 1
        return self.chart.add_differential(page, c1, c2, auto)

    @rule(o=classes,
          prop=st.sampled_from(["name", "group_name"]),
          val=st.text())
    def set_class_name(self, o, prop, val):
        setattr(o, prop, val)

    @rule(
        o=classes,
        prop=st.sampled_from(
            ["background_color", "border_color", "foreground_color"]),
        page_range=slices(),
        val=colors_strategy,
    )
    def set_class_color(self, o, prop, page_range, val):
        if page_range:
            getattr(o, prop)[page_range] = val
        else:
            setattr(o, prop, val)

    @rule(
        o=classes,
        prop=st.sampled_from(["border_width", "scale", "x_nudge", "y_nudge"]),
        page_range=slices(),
        val=integers,
    )
    def set_class_number(self, o, prop, page_range, val):
        if page_range:
            getattr(o, prop)[page_range] = val
        else:
            setattr(o, prop, val)

    @rule(obj_and_range=obj_and_range, val=st.booleans())
    def set_visible(self, obj_and_range, val):
        [o, page_range] = obj_and_range
        prop = "visible"
        if page_range:
            getattr(o, prop)[page_range] = val
        else:
            setattr(o, prop, val)

    @rule(edge_and_range=edge_and_range, val=colors_strategy)
    def set_edge_color(self, edge_and_range, val):
        [o, page_range] = edge_and_range
        prop = "color"
        if page_range:
            getattr(o, prop)[page_range] = val
        else:
            setattr(o, prop, val)

    @rule(edge_and_range=edge_and_range, val=st.lists(integers))
    def set_edge_dash_pattern(self, edge_and_range, val):
        [o, page_range] = edge_and_range
        prop = "dash_pattern"
        if page_range:
            getattr(o, prop)[page_range] = val
        else:
            setattr(o, prop, val)

    @rule(
        edge_and_range=edge_and_range,
        prop=st.sampled_from(["line_width", "bend"]),
        val=integers,
    )
    def set_edge_number(self, edge_and_range, prop, val):
        [o, page_range] = edge_and_range
        if page_range:
            getattr(o, prop)[page_range] = val
        else:
            setattr(o, prop, val)

    @rule()
    def check_num_classes(self):
        assert self.num_classes == len(self.chart.classes)

    @rule()
    def check_num_edges(self):
        assert self.num_edges == len(self.chart.edges)

    @rule()
    def double_serialize(self):
        s1 = JSON.stringify(self.chart)
        s2 = JSON.stringify(JSON.parse(s1))
        assert json.loads(s1) == json.loads(s2)

    @rule(o=st.one_of(
        consumes(classes),
        consumes(extension_bdl).filter(lambda x: not x._deleted),
        consumes(structline_bdl).filter(lambda x: not x._deleted),
        consumes(differential_bdl).filter(lambda x: not x._deleted),
    ))
    def delete_object(self, o):
        if isinstance(o, ChartClass):
            self.num_classes -= 1
            self.num_edges -= len(set(o.edges))
        else:
            self.num_edges -= 1
        o.delete()
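
# Hedged sketch: like the earlier machines, the chart machine above would
# typically be exposed as a generated test case; the variable name is arbitrary.
TestSseqChart = HypothesisStateMachine.TestCase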