class TestElections(unittest.TestCase):
    """Hypothesis-driven tests for election generation and ballot encryption."""

    @settings(
        deadline=timedelta(milliseconds=2000),
        suppress_health_check=[HealthCheck.too_slow],
        max_examples=10,
    )
    @given(election_descriptions())
    def test_generators_yield_valid_output(self, ed: ElectionDescription):
        """
        Tests that our Hypothesis election strategies generate "valid" output, also exercises the full stack
        of `is_valid` methods.
        """

        self.assertTrue(ed.is_valid())

    @settings(
        deadline=timedelta(milliseconds=10000),
        suppress_health_check=[HealthCheck.too_slow],
        max_examples=5,
        # disabling the "shrink" phase, because it runs very slowly
        phases=[Phase.explicit, Phase.reuse, Phase.generate, Phase.target],
    )
    @given(
        integers(1, 3).flatmap(lambda n: elections_and_ballots(n)),
        elements_mod_q(),
    )
    def test_accumulation_encryption_decryption(
        self,
        everything: ELECTIONS_AND_BALLOTS_TUPLE_TYPE,
        nonce: ElementModQ,
    ):
        """
        Tests that decryption is the inverse of encryption over arbitrarily generated elections and ballots.

        This test uses an arbitrarily generated dataset with a single public-private keypair for the election
        encryption context.  It also manually verifies that homomorphic accumulation works as expected.
        """
        # Arrange
        election_description, metadata, ballots, secret_key, context = everything

        # Tally the plaintext ballots for comparison later
        plaintext_tallies = accumulate_plaintext_ballots(ballots)
        num_ballots = len(ballots)
        num_contests = len(metadata.contests)

        # One nonce for the encryption of zero, then one nonce per ballot
        zero_nonce, *nonces = Nonces(nonce)[:num_ballots + 1]
        self.assertEqual(len(nonces), num_ballots)
        self.assertGreater(num_contests, 0)

        # Generate a valid encryption of zero
        encrypted_zero = elgamal_encrypt(0, zero_nonce,
                                         context.elgamal_public_key)

        # Act
        encrypted_ballots = []

        # encrypt each ballot with its own nonce
        for ballot, ballot_nonce in zip(ballots, nonces):
            encrypted_ballot = encrypt_ballot(ballot, metadata, context,
                                              SEED_HASH, ballot_nonce)

            # sanity check the encryption
            self.assertIsNotNone(encrypted_ballot)
            self.assertEqual(num_contests, len(encrypted_ballot.contests))
            encrypted_ballots.append(encrypted_ballot)

            # decrypt the ballot with secret and verify it matches the plaintext
            decrypted_ballot = decrypt_ballot_with_secret(
                ballot=encrypted_ballot,
                election_metadata=metadata,
                crypto_extended_base_hash=context.crypto_extended_base_hash,
                public_key=context.elgamal_public_key,
                secret_key=secret_key,
                remove_placeholders=True,
            )
            self.assertEqual(ballot, decrypted_ballot)

        # homomorphically accumulate the encrypted ballot representations
        encrypted_tallies = _accumulate_encrypted_ballots(
            encrypted_zero, encrypted_ballots)

        # decrypt every accumulated selection, keyed by selection object id
        decrypted_tallies = {
            object_id: encrypted_tally.decrypt(secret_key)
            for object_id, encrypted_tally in encrypted_tallies.items()
        }

        # loop through the contest descriptions and verify
        # the decrypted tallies match the plaintext tallies
        for contest in metadata.contests:
            # Sanity check the generated data
            self.assertGreater(len(contest.ballot_selections), 0)
            self.assertGreater(len(contest.placeholder_selections), 0)

            decrypted_selection_tallies = [
                decrypted_tallies[selection.object_id]
                for selection in contest.ballot_selections
            ]
            decrypted_placeholder_tallies = [
                decrypted_tallies[placeholder.object_id]
                for placeholder in contest.placeholder_selections
            ]
            plaintext_tally_values = [
                plaintext_tallies[selection.object_id]
                for selection in contest.ballot_selections
            ]

            # verify the plaintext tallies match the decrypted tallies
            self.assertEqual(decrypted_selection_tallies,
                             plaintext_tally_values)

            # validate the right number of selections including placeholders across all ballots
            self.assertEqual(
                contest.number_elected * num_ballots,
                sum(decrypted_selection_tallies) +
                sum(decrypted_placeholder_tallies),
            )


# Example 2
class TestTally(TestCase):
    """Tally tests"""

    @settings(
        deadline=timedelta(milliseconds=10000),
        suppress_health_check=[HealthCheck.too_slow],
        max_examples=3,
        # disabling the "shrink" phase, because it runs very slowly
        phases=[Phase.explicit, Phase.reuse, Phase.generate, Phase.target],
    )
    @given(integers(2, 5).flatmap(lambda n: elections_and_ballots(n)))
    def test_tally_cast_ballots_accumulates_valid_tally(
            self, everything: ELECTIONS_AND_BALLOTS_TUPLE_TYPE):
        """
        Tests that tallying CAST ballots and decrypting the result with the
        secret key reproduces the plaintext tallies.
        """
        # Arrange
        (
            _election_description,
            internal_manifest,
            ballots,
            secret_key,
            context,
        ) = everything
        # Tally the plaintext ballots for comparison later
        plaintext_tallies = accumulate_plaintext_ballots(ballots)

        # encrypt each ballot and store it as CAST
        store = self._encrypt_and_store_ballots(ballots, internal_manifest,
                                                context, BallotBoxState.CAST)

        # act
        result = tally_ballots(store, internal_manifest, context)
        self.assertIsNotNone(result)

        # Assert
        decrypted_tallies = self._decrypt_with_secret(result, secret_key)
        self.assertEqual(plaintext_tallies, decrypted_tallies)

    @settings(
        deadline=timedelta(milliseconds=10000),
        suppress_health_check=[HealthCheck.too_slow],
        max_examples=3,
        # disabling the "shrink" phase, because it runs very slowly
        phases=[Phase.explicit, Phase.reuse, Phase.generate, Phase.target],
    )
    @given(integers(1, 3).flatmap(lambda n: elections_and_ballots(n)))
    def test_tally_spoiled_ballots_accumulates_valid_tally(
            self, everything: ELECTIONS_AND_BALLOTS_TUPLE_TYPE):
        """
        Tests that SPOILED ballots leave every decrypted selection tally at
        zero while the tally's spoiled count equals the number of ballots.
        """
        # Arrange
        (
            _election_description,
            internal_manifest,
            ballots,
            secret_key,
            context,
        ) = everything
        # Tally the plaintext ballots for comparison later
        plaintext_tallies = accumulate_plaintext_ballots(ballots)

        # encrypt each ballot and store it as SPOILED
        store = self._encrypt_and_store_ballots(ballots, internal_manifest,
                                                context,
                                                BallotBoxState.SPOILED)

        # act
        tally = tally_ballots(store, internal_manifest, context)
        self.assertIsNotNone(tally)

        # Assert
        decrypted_tallies = self._decrypt_with_secret(tally, secret_key)
        # same selection ids on both sides (iterating a dict compares its keys)
        self.assertCountEqual(plaintext_tallies, decrypted_tallies)
        for value in decrypted_tallies.values():
            self.assertEqual(0, value)
        self.assertEqual(len(ballots), tally.spoiled())

    @settings(
        deadline=timedelta(milliseconds=10000),
        suppress_health_check=[HealthCheck.too_slow],
        max_examples=3,
        # disabling the "shrink" phase, because it runs very slowly
        phases=[Phase.explicit, Phase.reuse, Phase.generate, Phase.target],
    )
    @given(integers(1, 3).flatmap(lambda n: elections_and_ballots(n)))
    def test_tally_ballot_invalid_input_fails(
            self, everything: ELECTIONS_AND_BALLOTS_TUPLE_TYPE):
        """
        Tests that a tally rejects ballots in an UNKNOWN state, tampered
        ballots, and ballots that were already submitted.
        """

        # Arrange
        (
            _election_description,
            internal_manifest,
            ballots,
            _secret_key,
            context,
        ) = everything

        # encrypt each ballot and store it as CAST
        store = self._encrypt_and_store_ballots(ballots, internal_manifest,
                                                context, BallotBoxState.CAST)

        tally = CiphertextTally("my-tally", internal_manifest, context)

        # act
        cached_ballots = store.all()
        first_ballot = cached_ballots[0]
        first_ballot.state = BallotBoxState.UNKNOWN

        # verify an UNKNOWN state ballot fails
        self.assertIsNone(tally_ballot(first_ballot, tally))
        self.assertFalse(tally.append(first_ballot))

        # cast a ballot
        first_ballot.state = BallotBoxState.CAST
        self.assertTrue(tally.append(first_ballot))

        # try to append a spoiled ballot
        first_ballot.state = BallotBoxState.SPOILED
        self.assertFalse(tally.append(first_ballot))

        # Verify accumulation fails if the selection collection is empty
        # NOTE(review): this branch never runs, because the state was set to
        # SPOILED just above. It also indexes `tally.contests` with the ballot's
        # object id, whereas the helper below indexes it with a contest object
        # id - confirm the intended check before making it reachable.
        if first_ballot.state == BallotBoxState.CAST:
            self.assertFalse(
                tally.contests[first_ballot.object_id].accumulate_contest([]))

        # pylint: disable=protected-access
        # pop the cast ballot
        tally._cast_ballot_ids.pop()

        # reset to cast
        first_ballot.state = BallotBoxState.CAST

        self.assertTrue(
            self._cannot_erroneously_mutate_state(tally, first_ballot,
                                                  BallotBoxState.CAST))

        self.assertTrue(
            self._cannot_erroneously_mutate_state(tally, first_ballot,
                                                  BallotBoxState.SPOILED))

        self.assertTrue(
            self._cannot_erroneously_mutate_state(tally, first_ballot,
                                                  BallotBoxState.UNKNOWN))

        # verify a cast ballot cannot be added twice
        first_ballot.state = BallotBoxState.CAST
        self.assertTrue(tally.append(first_ballot))
        self.assertFalse(tally.append(first_ballot))

        # verify an already submitted ballot cannot be changed or readded
        first_ballot.state = BallotBoxState.SPOILED
        self.assertFalse(tally.append(first_ballot))

    def _encrypt_and_store_ballots(
        self,
        ballots,
        internal_manifest,
        context,
        state: BallotBoxState,
    ) -> DataStore:
        """
        Encrypts every ballot and stores it in a new `DataStore` under the
        given ballot box state.

        The encryption seed starts at the encryption device hash and is then
        chained: each ballot is encrypted with the `code` of the previous
        encrypted ballot as its seed.

        :param ballots: plaintext ballots to encrypt
        :param internal_manifest: manifest passed through to `encrypt_ballot`
        :param context: encryption context passed through to `encrypt_ballot`
        :param state: state assigned to every stored ballot
        :return: the store holding one submitted ballot per input ballot
        """
        store = DataStore()
        encryption_seed = ElectionFactory.get_encryption_device().get_hash()
        for ballot in ballots:
            encrypted_ballot = encrypt_ballot(ballot, internal_manifest,
                                              context, encryption_seed)
            # check before dereferencing, so a failed encryption surfaces as
            # an assertion failure rather than an AttributeError on `.code`
            self.assertIsNotNone(encrypted_ballot)
            encryption_seed = encrypted_ballot.code
            # add to the ballot store
            store.set(
                encrypted_ballot.object_id,
                from_ciphertext_ballot(encrypted_ballot, state),
            )
        return store

    @staticmethod
    def _decrypt_with_secret(tally: CiphertextTally,
                             secret_key: ElementModQ) -> Dict[str, int]:
        """
        Demonstrates how to decrypt a tally with a known secret key

        :param tally: the tally whose contest selections are decrypted
        :param secret_key: the key passed to each selection ciphertext's `decrypt`
        :return: decrypted value per selection object id, across all contests
        """
        plaintext_selections: Dict[str, int] = {}
        for contest in tally.contests.values():
            for object_id, selection in contest.selections.items():
                plaintext_tally = selection.ciphertext.decrypt(secret_key)
                plaintext_selections[object_id] = plaintext_tally

        return plaintext_selections

    def _cannot_erroneously_mutate_state(
        self,
        tally: CiphertextTally,
        ballot: SubmittedBallot,
        state_to_test: BallotBoxState,
    ) -> bool:
        """
        Asserts that the tally rejects the ballot under a series of
        tamperings while the ballot is in `state_to_test`.

        Each tampering (removed selection, changed contest description hash,
        changed contest object id, changed selection object id, changed
        manifest hash) is applied, checked, and then reverted, so the ballot
        is restored when every assertion passes. The ballot's original state
        is restored at the end as well.

        :param tally: the tally that must reject the tampered ballot
        :param ballot: the ballot to tamper with (mutated and restored in place)
        :param state_to_test: the state the ballot holds during the checks
        :return: always True; failures are reported through assertions
        """

        input_state = ballot.state
        ballot.state = state_to_test

        # remove the first selection
        first_contest = ballot.contests[0]
        first_selection = first_contest.ballot_selections[0]
        ballot.contests[0].ballot_selections.remove(first_selection)

        self.assertIsNone(tally_ballot(ballot, tally))
        self.assertFalse(tally.append(ballot))

        # Verify accumulation fails if the selection count does not match
        if ballot.state == BallotBoxState.CAST:
            first_tally = tally.contests[first_contest.object_id]
            self.assertFalse(
                first_tally.accumulate_contest(
                    ballot.contests[0].ballot_selections))

            # pylint: disable=protected-access
            _key, bad_accumulation = first_tally._accumulate_selections(
                first_selection.object_id,
                first_tally.selections[first_selection.object_id],
                ballot.contests[0].ballot_selections,
            )
            self.assertIsNone(bad_accumulation)

        # put the removed selection back in its original position
        ballot.contests[0].ballot_selections.insert(0, first_selection)

        # modify the contest description hash
        first_contest_hash = ballot.contests[0].description_hash
        ballot.contests[0].description_hash = ONE_MOD_Q
        self.assertIsNone(tally_ballot(ballot, tally))
        self.assertFalse(tally.append(ballot))

        ballot.contests[0].description_hash = first_contest_hash

        # modify a contest object id
        first_contest_object_id = ballot.contests[0].object_id
        ballot.contests[0].object_id = "a-bad-object-id"
        self.assertIsNone(tally_ballot(ballot, tally))
        self.assertFalse(tally.append(ballot))

        ballot.contests[0].object_id = first_contest_object_id

        # modify a selection object id
        first_contest_selection_object_id = (
            ballot.contests[0].ballot_selections[0].object_id)
        ballot.contests[0].ballot_selections[
            0].object_id = "another-bad-object-id"

        self.assertIsNone(tally_ballot(ballot, tally))
        self.assertFalse(tally.append(ballot))

        # Verify accumulation fails if the selection object id does not match
        if ballot.state == BallotBoxState.CAST:
            self.assertFalse(tally.contests[
                ballot.contests[0].object_id].accumulate_contest(
                    ballot.contests[0].ballot_selections))

        ballot.contests[0].ballot_selections[
            0].object_id = first_contest_selection_object_id

        # modify the ballot's hash
        first_ballot_hash = ballot.manifest_hash
        ballot.manifest_hash = ONE_MOD_Q
        self.assertIsNone(tally_ballot(ballot, tally))
        self.assertFalse(tally.append(ballot))

        ballot.manifest_hash = first_ballot_hash
        ballot.state = input_state

        return True