Example #1
def fast_tally_ballots(
    ballots: Sequence[CiphertextBallot],
    pool: Optional[Pool] = None,
) -> TALLY_TYPE:
    """
    This function does a tally of the given list of ballots, returning a dictionary that maps
    from selection object_ids to the ElGamalCiphertext that corresponds to the encrypted tally
    of that selection. An optional `Pool` may be passed in, and it will be used to evaluate
    the ElGamal accumulation in parallel. If it's absent, then the accumulation will happen
    sequentially. Progress bars are not currently supported.
    """

    iter_count = 1
    initial_tallies: Sequence[TALLY_INPUT_TYPE] = ballots

    while True:
        if pool is None or len(initial_tallies) <= BALLOTS_PER_SHARD:
            log_and_print(
                f"tally iteration {iter_count} (FINAL): {len(initial_tallies)} partial tallies"
            )
            return sequential_tally(initial_tallies)

        shards = shard_list_uniform(initial_tallies, BALLOTS_PER_SHARD)
        log_and_print(
            f"tally iteration {iter_count}: {len(initial_tallies)} partial tallies --> {len(shards)} shards"
        )
        partial_tallies: Sequence[TALLY_TYPE] = pool.map(func=sequential_tally,
                                                         iterable=shards)

        iter_count += 1
        initial_tallies = partial_tallies
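
A minimal usage sketch (hedged: `encrypted_ballots` is a stand-in for any `Sequence[CiphertextBallot]` produced by an earlier encryption step; it is not defined in these examples):

```
from multiprocessing import Pool

# Hypothetical driver: `encrypted_ballots` stands in for a
# Sequence[CiphertextBallot] produced elsewhere.
with Pool() as pool:
    parallel_tally = fast_tally_ballots(encrypted_ballots, pool)

# Passing no pool exercises the sequential path; both should agree.
assert parallel_tally == fast_tally_ballots(encrypted_ballots)
```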
Example #2
    # The parameters imply property-based testing; a hypothesis decorator along
    # these lines is assumed (the exact bounds are not shown in the source):
    @given(integers(min_value=1, max_value=20), integers(min_value=1, max_value=100))
    def test_shard_list_uniform(self, num_per_group: int,
                                total_inputs: int) -> None:
        inputs = list(range(1, total_inputs + 1))
        shards = shard_list_uniform(inputs, num_per_group)
        self.assertTrue(len(shards) > 0)
        min_length = min([len(x) for x in shards])
        max_length = max([len(x) for x in shards])
        self.assertTrue(min_length == max_length
                        or min_length + 1 == max_length)
        self.assertTrue(max_length <= num_per_group)
        if num_per_group > 2 and total_inputs > 1:
            self.assertTrue(min_length >= 2)
        self.assertEqual(inputs, flatmap(lambda x: x, shards))
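
The invariants above pin down `shard_list_uniform` fairly tightly. A minimal sketch of an implementation satisfying them (hypothetical, not the library's actual code) might look like:

```
from typing import List, Sequence, TypeVar

T = TypeVar("T")

def shard_list_uniform_sketch(
    input: Sequence[T], num_per_group: int
) -> Sequence[Sequence[T]]:
    """Hypothetical reimplementation: order-preserving shards whose sizes
    differ by at most one and never exceed num_per_group."""
    assert num_per_group >= 1
    items = list(input)
    if not items:
        return []
    # fewest shards such that none exceeds num_per_group (ceiling division)
    num_groups = (len(items) + num_per_group - 1) // num_per_group
    base, extra = divmod(len(items), num_groups)
    shards: List[Sequence[T]] = []
    offset = 0
    for i in range(num_groups):
        size = base + (1 if i < extra else 0)  # spread the remainder evenly
        shards.append(items[offset : offset + size])
        offset += size
    return shards
```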
Example #3
    def all_proofs_valid(
        self,
        verbose: bool = False,
        recheck_ballots_and_tallies: bool = False,
        use_progressbar: bool = True,
    ) -> bool:
        """
        Checks all the proofs used in this tally, returns True if everything is good.
        Any errors found will be logged. Normally, this only checks the proofs associated
        with the totals. If you want to also recompute the tally (i.e., tabulate the
        encrypted ballots) and verify every individual ballot proof, then set
        `recheck_ballots_and_tallies` to True.
        """

        ray_wait_for_workers(min_workers=2)

        log_and_print("Verifying proofs.", verbose)

        r_public_key = ray.put(self.context.elgamal_public_key)
        r_hash_header = ray.put(self.context.crypto_extended_base_hash)

        start = timer()
        selections = self.tally.map.values()
        sharded_selections: Sequence[
            Sequence[SelectionInfo]] = shard_list_uniform(selections, 2)

        # parallelizing this is overkill, but why not?
        results: List[bool] = ray.get([
            r_verify_tally_selection_proofs.remote(r_public_key, r_hash_header,
                                                   *s)
            for s in sharded_selections
        ])
        end = timer()

        log_and_print(f"Verification time: {end - start: .3f} sec", verbose)
        log_and_print(
            f"Verification rate: {len(self.tally.map.keys()) / (end - start): .3f} selection/sec",
            verbose,
        )

        if False in results:
            return False

        if recheck_ballots_and_tallies:
            if self.manifest is None:
                log_and_print(
                    "cannot recheck ballots and tallies without a manifest")
                return False

            # next, check each individual ballot's proofs; the progress bar, if
            # enabled, is shown here even when verbose is false
            num_ballots = self.num_ballots

            r_manifest = ray.put(self.manifest)

            progressbar = (ProgressBar({
                "Ballots": num_ballots,
                "Tallies": num_ballots,
                "Iterations": 0,
                "Batch": 0,
            }) if use_progressbar else None)
            progressbar_actor = progressbar.actor if progressbar is not None else None

            ballot_start = timer()

            batches: Sequence[Sequence[str]] = shard_list_uniform(
                self.cvr_metadata["BallotId"], BATCH_SIZE)

            # List[ObjectRef[Optional[TALLY_TYPE]]]
            recomputed_tallies: List[ObjectRef] = []

            for batch in batches:
                if progressbar_actor:
                    progressbar_actor.update_completed.remote("Batch", 1)

                cballot_manifest_name_shards: Sequence[
                    Sequence[str]] = shard_list_uniform(
                        batch, BALLOTS_PER_SHARD)

                # List[ObjectRef[Optional[TALLY_TYPE]]]
                ballot_results: List[ObjectRef] = [
                    r_verify_ballot_proofs.remote(
                        r_manifest,
                        r_public_key,
                        r_hash_header,
                        progressbar_actor,
                        *shard,
                    ) for shard in cballot_manifest_name_shards
                ]
                # ray.wait(
                #     ballot_results,
                #     num_returns=len(cballot_manifest_name_shards),
                #     timeout=None,
                # )
                # log_and_print("Recomputing tallies.", verbose)

                ptally = ray_tally_ballots(ballot_results,
                                           PARTIAL_TALLIES_PER_SHARD,
                                           progressbar)
                recomputed_tallies.append(ptally)

            if len(recomputed_tallies) > 1:
                recomputed_tally = ray.get(
                    ray_tally_ballots(recomputed_tallies,
                                      PARTIAL_TALLIES_PER_SHARD, progressbar))
            else:
                recomputed_tally = ray.get(recomputed_tallies[0])

            if progressbar:
                progressbar.close()

            if not recomputed_tally:
                return False

            ballot_end = timer()

            log_and_print(
                f"Ballot verification rate: {num_ballots / (ballot_end - ballot_start): .3f} ballot/sec",
                True,
            )

            tally_success = tallies_match(self.tally.to_tally_map(),
                                          recomputed_tally)

            if not tally_success:
                return False

        return True
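
A hedged usage sketch, assuming `results` holds the `RayTallyEverythingResults` produced by `ray_tally_everything` (Example #4) and that a Ray cluster is already up:

```
ray_localhost_init()  # or ray.init(...) against a real cluster

ok = results.all_proofs_valid(
    verbose=True,
    recheck_ballots_and_tallies=True,  # also re-tabulate every ballot
    use_progressbar=True,
)
assert ok, "tally verification failed"
```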
Example #4
def ray_tally_everything(
    cvrs: DominionCSV,
    verbose: bool = True,
    use_progressbar: bool = True,
    date: Optional[datetime] = None,
    seed_hash: Optional[ElementModQ] = None,
    master_nonce: Optional[ElementModQ] = None,
    secret_key: Optional[ElementModQ] = None,
    root_dir: Optional[str] = None,
) -> "RayTallyEverythingResults":
    """
    This top-level function takes a collection of Dominion CVRs and produces everything that
    we might want for arlo-e2e: a list of encrypted ballots, their encrypted and decrypted tally,
    and proofs of the correctness of the whole thing. The election `secret_key` is an optional
    parameter. If absent, a random keypair is generated and used. Similarly, if a `seed_hash` or
    `master_nonce` is not provided, random ones are generated and used.

    For parallelism, Ray is used. Make sure you've called `ray.init()` or `ray_localhost_init()`
    before calling this.

    If `root_dir` is specified, then the tally is written out to the specified directory, and
    the resulting `RayTallyEverythingResults` object will support the methods that allow those
    ballots to be read back in again. Conversely, if `root_dir` is `None`, then nothing is
    written to disk, and the result will not have access to individual ballots.
    """

    rows, cols = cvrs.data.shape

    ray_wait_for_workers(min_workers=2)

    if date is None:
        date = datetime.now()

    if root_dir is not None:
        mkdir_helper(root_dir, num_retries=NUM_WRITE_RETRIES)
        r_manifest_aggregator = ManifestAggregatorActor.remote(
            root_dir)  # type: ignore
    else:
        r_manifest_aggregator = None

    r_root_dir = ray.put(root_dir)

    start_time = timer()

    # Performance note: by using to_election_description_ray rather than to_election_description, we're
    # only getting back a list of dictionaries rather than a list of PlaintextBallots. We're pushing that
    # work out into the nodes, where it will run in parallel. The BallotPlaintextFactory wraps up all
    # the (immutable) state necessary to convert from these dicts to PlaintextBallots and is meant to
    # be sent to every node in the cluster.

    ed, bpf, ballot_dicts, id_map = cvrs.to_election_description_ray(date=date)
    setup_time = timer()
    num_ballots = len(ballot_dicts)
    assert num_ballots > 0, "can't have zero ballots!"
    log_and_print(
        f"ElectionGuard setup time: {setup_time - start_time: .3f} sec, {num_ballots / (setup_time - start_time):.3f} ballots/sec"
    )

    keypair = (elgamal_keypair_random() if secret_key is None else
               elgamal_keypair_from_secret(secret_key))
    assert keypair is not None, "unexpected failure with keypair computation"
    secret_key, public_key = keypair

    cec = make_ciphertext_election_context(
        number_of_guardians=1,
        quorum=1,
        elgamal_public_key=public_key,
        description_hash=ed.crypto_hash(),
    )
    r_cec = ray.put(cec)

    ied = InternalElectionDescription(ed)
    r_ied = ray.put(ied)

    if seed_hash is None:
        seed_hash = rand_q()
    r_seed_hash = ray.put(seed_hash)
    r_keypair = ray.put(keypair)

    r_ballot_plaintext_factory = ray.put(bpf)

    if master_nonce is None:
        master_nonce = rand_q()

    nonces = Nonces(master_nonce)
    r_nonces = ray.put(nonces)
    nonce_indices = range(num_ballots)

    inputs = list(zip(ballot_dicts, nonce_indices))

    batches = shard_list_uniform(inputs, BATCH_SIZE)
    num_batches = len(batches)
    log_and_print(
        f"Launching Ray.io remote encryption! (number of batches: {num_batches})"
    )

    start_time = timer()

    progressbar = (ProgressBar({
        "Ballots": num_ballots,
        "Tallies": num_ballots,
        "Iterations": 0,
        "Batch": 0,
    }) if use_progressbar else None)
    progressbar_actor = progressbar.actor if progressbar is not None else None

    batch_tallies: List[ObjectRef] = []
    for batch in batches:
        if progressbar_actor:
            progressbar_actor.update_completed.remote("Batch", 1)

        num_ballots_in_batch = len(batch)
        sharded_inputs = shard_list_uniform(batch, BALLOTS_PER_SHARD)
        num_shards = len(sharded_inputs)

        partial_tally_refs = [
            r_encrypt_and_write.remote(
                r_ied,
                r_cec,
                r_seed_hash,
                r_root_dir,
                r_manifest_aggregator,
                progressbar_actor,
                r_ballot_plaintext_factory,
                r_nonces,
                right_tuple_list(shard),
                *(left_tuple_list(shard)),
            ) for shard in sharded_inputs
        ]

        # log_and_print("Remote tallying.")
        btally = ray_tally_ballots(partial_tally_refs, BALLOTS_PER_SHARD,
                                   progressbar)
        batch_tallies.append(btally)

    # Each batch ultimately yields one partial tally; we add these up here at the
    # very end. If we have a million ballots and have batches of 10k ballots, this
    # would mean we'd have only 100 partial tallies. So, what's here works just fine.
    # If we wanted, we could certainly burn some scalar time and keep a running,
    # singular, partial tally. It's probably more important to push onward to the
    # next batch, so we can do as much work in parallel as possible.

    if len(batch_tallies) > 1:
        tally = ray.get(ray_tally_ballots(batch_tallies, 10, progressbar))
    else:
        tally = ray.get(batch_tallies[0])

    if progressbar:
        progressbar.close()

    assert tally is not None, "tally failed!"

    log_and_print("Tally decryption.")
    decrypted_tally: DECRYPT_TALLY_OUTPUT_TYPE = ray_decrypt_tally(
        tally, r_cec, r_keypair, seed_hash)

    log_and_print("Validating tally.")

    # Sanity-checking logic: make sure we don't have any unexpected keys, and that the decrypted totals
    # match up with the columns in the original plaintext data.
    tally_keys = set(decrypted_tally.keys())
    expected_keys = set(id_map.keys())

    assert tally_keys.issubset(
        expected_keys
    ), f"bad tally keys (actual keys: {sorted(tally_keys)}, expected keys: {sorted(expected_keys)})"

    for obj_id in decrypted_tally.keys():
        cvr_sum = int(cvrs.data[id_map[obj_id]].sum())
        decryption, proof = decrypted_tally[obj_id]
        assert cvr_sum == decryption, f"decryption failed for {obj_id}"

    final_manifest: Optional[Manifest] = None

    if root_dir is not None:
        final_manifest = ray.get(r_manifest_aggregator.result.remote())
        assert isinstance(
            final_manifest,
            Manifest), "type error: bad result from manifest aggregation"

    # Assemble the data structure that we're returning. Having nonces in the ciphertext makes these
    # structures sensitive for writing out to disk, but otherwise they're ready to go.
    log_and_print("Constructing results.")
    reported_tally: Dict[str, SelectionInfo] = {
        k: SelectionInfo(
            object_id=k,
            encrypted_tally=tally[k],
            # we need to forcibly convert mpz to int here to make serialization work properly
            decrypted_tally=int(decrypted_tally[k][0]),
            proof=decrypted_tally[k][1],
        )
        for k in tally.keys()
    }

    tabulate_time = timer()

    log_and_print(
        f"Encryption and tabulation: {rows} ballots, {rows / (tabulate_time - start_time): .3f} ballot/sec",
        verbose,
    )

    return RayTallyEverythingResults(
        metadata=cvrs.metadata,
        cvr_metadata=cvrs.dataframe_without_selections(),
        election_description=ed,
        num_ballots=rows,
        manifest=final_manifest,
        tally=SelectionTally(reported_tally),
        context=cec,
    )
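
A hypothetical end-to-end driver (hedged: `read_dominion_csv` is assumed to be the loader that yields a `DominionCSV`; it does not appear in these examples):

```
ray_localhost_init()  # per the docstring, Ray must be initialized first

cvrs = read_dominion_csv("cvrs.csv")  # assumed loader returning DominionCSV
results = ray_tally_everything(
    cvrs,
    verbose=True,
    root_dir="tally_output",  # persist encrypted ballots for later reads
)
assert results.all_proofs_valid()
```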
Example #5
    def test_shard_list_zero_input(self) -> None:
        self.assertEqual([], shard_list([], 3))
        self.assertEqual([], shard_list_uniform([], 3))
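
Zero input aside, the two functions differ in how they balance the final shard. A plausible illustration (hedged: this assumes `shard_list` packs greedily while `shard_list_uniform` evens out the sizes, consistent with the `min_length >= 2` assertion in Example #2):

```
shard_list(list(range(7)), 3)          # e.g., [[0, 1, 2], [3, 4, 5], [6]]
shard_list_uniform(list(range(7)), 3)  # e.g., [[0, 1, 2], [3, 4], [5, 6]]
```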
Example #6
def ray_reduce_with_ray_wait(
    inputs: Iterable[ObjectRef],
    shard_size: int,
    reducer_first_arg: Any,
    reducer: Callable,  # Callable[[Any, VarArg(ObjectRef)], ObjectRef]
    progressbar: Optional[ProgressBar] = None,
    progressbar_key: Optional[str] = None,
    timeout: Optional[float] = None,
    verbose: bool = False,
) -> ObjectRef:
    """
    Given a list of inputs and a Ray remote reducer, manages the Ray cluster, waiting for values
    as they become ready and calling the reducer until everything is reduced to a single value.
    An ObjectRef to that result is returned.

    The `shard_size` parameter specifies how many inputs should be fed to each call to the reducer.
    Since the available data will vary, the actual number fed to the reducer will be at least two
    and at most `shard_size`.

    The `timeout` specifies how many seconds each wait for results may last. If
    `shard_size*shard_size` results become available sooner, the wait returns early. Otherwise,
    as long as at least two results are available when the timeout expires, at least one reducer
    call will be dispatched.

    (Why `shard_size*shard_size`? If `shard_size` were 10, we could dispatch ten calls to the
    reducer with ten inputs each, meaning fewer trips through `ray.wait`. The timeout takes
    precedence over this, guaranteeing a minimum dispatch rate.)

    The `reducer` is a Ray remote method reference that takes a first argument of whatever type,
    followed by a varargs sequence of inputs (passed as ObjectRefs, which Ray resolves before the
    reducer body runs), and returns an ObjectRef. So, if you had
    code that looked like:

    ```
    @ray.remote
    def my_reducer(config: Config, *inputs: MyDataType) -> MyDataType:
        ...
    ```

    And let's say you're mapping some remote function to generate those values and later want
    to reduce them. That code might look like this:
    ```
    @ray.remote
    def my_mapper(input: SomethingElse) -> MyDataType:
        ...

    def run_everything(config: Config, inputs: Iterable[SomethingElse]) -> MyDataType:
        map_refs = [my_mapper.remote(i) for i in inputs]
        return ray_reduce_with_ray_wait(map_refs, 10, config, my_reducer.remote)
    ```

    If your `reducer_first_arg` corresponds to some large object that you don't want to serialize
    over and over, you could of course call `ray.put` on it first and pass that along.

    Important assumption: the `reducer` function needs to be *associative* and *commutative*.
    Ordering from the original list of inputs is *not* maintained.

    Optional feature: integration with the progressbar in `ray_progress`. Just pass in the
    ProgressBar as well as the `key` string that you want to use. Whenever more work
    is being dispatched, the progressbar's total amount of work is updated by the dispatcher here.
    The work completion notification is *not* handled here. That needs to be done by the remote
    reducer. (Why? Because it might want to update the progressbar for each element in the shard
    while here we could only see when the whole shard is completed.)
    """

    # TODO: generalize this code so the `reducer_first_arg` is wrapped up in the reducer.
    #   This seems like a job for `kwargs`. Deal with that after everything else works.

    assert (progressbar_key and progressbar
            ) or not progressbar, "progress bar requires a key string"
    assert shard_size > 1, "shard_size must be greater than one"
    assert timeout is None or timeout > 0, "negative timeouts aren't allowed"

    iteration_count = 0
    inputs = list(inputs)
    result: Optional[ObjectRef] = None

    while inputs:
        if progressbar:
            progressbar.actor.update_completed.remote("Iterations", 1)
            progressbar.print_update()

        iteration_count += 1
        # log_and_print(
        #     f"REDUCER ITERATION {iteration_count}: starting with {len(inputs)}",
        #     verbose=verbose,
        # )
        num_inputs = len(inputs)
        max_returns = shard_size * shard_size
        num_returns = max_returns if num_inputs >= max_returns else num_inputs
        tmp: Tuple[List[ObjectRef],
                   List[ObjectRef]] = ray.wait(inputs,
                                               num_returns=num_returns,
                                               timeout=timeout)
        ready_refs, pending_refs = tmp
        num_ready_refs = len(ready_refs)
        num_pending_refs = len(pending_refs)
        assert (num_inputs == num_pending_refs +
                num_ready_refs), "ray.wait fail: we lost some inputs!"

        # log_and_print(
        #     f"ray.wait() returned: ready({num_ready_refs}), pending({num_pending_refs})",
        #     verbose=verbose,
        # )

        if num_ready_refs == 1 and num_pending_refs == 0:
            # terminal case: we have one result ready and nothing pending; we're done!
            # log_and_print("Complete!", verbose=verbose)
            result = ready_refs[0]
            break
        if num_ready_refs >= 2:
            # general case: we have at least two results ready

            shards = shard_list_uniform(ready_refs, shard_size)
            size_one_shards = [s for s in shards if len(s) == 1]
            usable_shards = [s for s in shards if len(s) > 1]
            total_usable = sum(len(s) for s in usable_shards)
            # log_and_print(
            #     f"launching reduction: {total_usable} total usable values in {len(usable_shards)} shards, {len(size_one_shards)} size-one shards",
            #     verbose=verbose,
            # )

            if progressbar:
                progressbar.actor.update_total.remote(progressbar_key,
                                                      total_usable)

            # dispatches jobs to remote workers, returns immediately with ObjectRefs
            partial_results = [
                reducer(reducer_first_arg, *s) for s in usable_shards
            ]

            inputs = list(partial_results + pending_refs +
                          [x[0] for x in size_one_shards])

            assert len(
                inputs
            ) < num_inputs, "reducer fail: we didn't shrink the inputs"
        else:
            # annoying case: at most one result is ready while others are still
            # pending; nothing useful to do yet, so go around and wait again
            pass

    assert result is not None, "reducer fail: somehow exited the loop with no result"
    return result
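
To make the reducer's calling convention concrete, here is a hypothetical tally reducer in the expected shape (hedged: `add_tallies` is invented for illustration; `sequential_tally` and `TALLY_TYPE` come from Example #1, and `partial_tally_refs`, `progressbar`, and `progressbar_actor` are assumed from an earlier map phase as in Example #4):

```
from typing import Optional
import ray
from ray.actor import ActorHandle

@ray.remote
def add_tallies(progress: Optional[ActorHandle],
                *tallies: TALLY_TYPE) -> TALLY_TYPE:
    # varargs arrive as resolved TALLY_TYPE values, not ObjectRefs
    result = sequential_tally(tallies)
    if progress is not None:
        # per the docstring, completion updates belong in the reducer
        progress.update_completed.remote("Tallies", len(tallies))
    return result

tally_ref = ray_reduce_with_ray_wait(
    inputs=partial_tally_refs,  # List[ObjectRef] from the map phase
    shard_size=10,
    reducer_first_arg=progressbar_actor,
    reducer=add_tallies.remote,
    progressbar=progressbar,
    progressbar_key="Tallies",
    timeout=5.0,
)
```

Note that tally addition is both associative and commutative, satisfying the stated requirement.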
Example #7
def ray_reduce_with_rounds(
    inputs: Iterable[ObjectRef],
    shard_size: int,
    reducer_first_arg: Any,
    reducer: Callable,  # Callable[[Any, VarArg(ObjectRef)], ObjectRef]
    progressbar: Optional[ProgressBar] = None,
    progressbar_key: Optional[str] = None,
    verbose: bool = False,
) -> ObjectRef:
    """
    Given a list of inputs and a Ray remote reducer, manages the Ray cluster, waiting for values
    as they become ready and calling the reducer until everything is reduced to a single value.
    Unlike `ray_reduce_with_ray_wait`, this version builds a reduction tree. It requires the
    reducer to be associative, but not commutative.

    The `shard_size` parameter specifies how many inputs should be fed to each call to the reducer.
    Since the available data will vary, the actual number fed to the reducer will be at least two
    and at most `shard_size`.

    The `reducer` is a Ray remote method reference that takes a first argument of whatever type,
    followed by a varargs sequence of inputs (passed as ObjectRefs, which Ray resolves before the
    reducer body runs), and returns an ObjectRef. So, if you had
    code that looked like:

    ```
    @ray.remote
    def my_reducer(config: Config, *inputs: MyDataType) -> MyDataType:
        ...
    ```

    And let's say you're mapping some remote function to generate those values and later want
    to reduce them. That code might look like this:
    ```
    @ray.remote
    def my_mapper(input: SomethingElse) -> MyDataType:
        ...

    def run_everything(config: Config, inputs: Iterable[SomethingElse]) -> MyDataType:
        map_refs = [my_mapper.remote(i) for i in inputs]
        return ray_reduce_with_rounds(map_refs, 10, config, my_reducer.remote)
    ```

    If your `reducer_first_arg` corresponds to some large object that you don't want to serialize
    over and over, you could of course call `ray.put` on it first and pass that along.

    Optional feature: integration with the progressbar in `ray_progress`. Just pass in the
    ProgressBar as well as the `key` string that you want to use. Whenever more work
    is being dispatched, the progressbar's total amount of work is updated by the dispatcher here.
    The work completion notification is *not* handled here. That needs to be done by the remote
    reducer. (Why? Because it might want to update the progressbar for each element in the shard
    while here we could only see when the whole shard is completed.)
    """

    # TODO: generalize this code so the `reducer_first_arg` is wrapped up in the reducer.
    #   This seems like a job for `kwargs`. Deal with that after everything else works.

    assert (progressbar_key and progressbar
            ) or not progressbar, "progress bar requires a key string"

    assert shard_size > 1, "shard_size must be greater than one"

    progressbar_actor = progressbar.actor if progressbar is not None else None
    iter_count = 1

    result: Optional[ObjectRef] = None

    inputs = list(inputs)

    while True:
        num_inputs = len(inputs)

        if progressbar_actor is not None:
            progressbar_actor.update_completed.remote("Iterations", 1)
            progressbar_actor.update_total.remote(progressbar_key, num_inputs)

        if num_inputs <= shard_size:
            log_and_print(f"Reduction (FINAL): {num_inputs} partial results",
                          verbose=verbose)
            result = reducer(reducer_first_arg, *inputs)
            break

        # Sequence[Sequence[ObjectRef[Optional[TALLY_TYPE]]]]
        shards: Sequence[Sequence[ObjectRef]] = shard_list_uniform(
            inputs, shard_size)

        log_and_print(
            f"Reduction {iter_count:2d}: {num_inputs:6d} partial results --> {len(shards)} shards (bps = {shard_size})",
            verbose=verbose,
        )

        # Sequence[ObjectRef[Optional[TALLY_TYPE]]]
        partial_results: List[ObjectRef] = [
            reducer(reducer_first_arg, *shard) for shard in shards
        ]

        # To avoid deeply nested tasks, we could wait here for the partial results
        # to finish, as in the commented-out ray.wait() below. Everything still
        # works without it, but you can get warnings about too many tasks.
        # ray.wait(partial_results, num_returns=len(partial_results), timeout=None)

        iter_count += 1
        inputs = partial_results

    if progressbar:
        progressbar.print_until_done()
    assert result is not None, "while loop shouldn't have broken without setting result"
    return result
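
As a worked example of the reduction tree: with 100,000 input refs and `shard_size=10`, round 1 dispatches 10,000 reducer calls yielding 10,000 partial results, round 2 yields 1,000, then 100, then 10, and the final call reduces those last 10 to one, for ceil(log10(100,000)) = 5 levels in all. Because `shard_list_uniform` preserves input order and each shard is reduced contiguously, associativity of the reducer suffices here; commutativity is only needed by `ray_reduce_with_ray_wait`, where completion order scrambles the inputs.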