예제 #1
0
    def _process(share, internal, external):
        """Build this server's second-round verification packet for one row.

        Loads the client's share into a fresh verifier, parses the internal
        and external first-round packets, and combines them into a
        PrioPacketVerify2.

        Returns the bytes from ``PrioPacketVerify2_write``, or ``None`` when
        libprio raises RuntimeError, ValueError or TypeError for this row.
        """
        # Key material, n_data, batch_id, server_id and shared_secret come
        # from the enclosing scope.
        keys = import_keys(
            private_key_hex, public_key_hex_internal, public_key_hex_external)
        private_key, pk_internal, pk_external = keys

        config = libprio.PrioConfig_new(
            n_data, pk_internal, pk_external, batch_id)
        server = libprio.PrioServer_new(
            config, match_server(server_id), private_key, shared_secret)
        verifier = libprio.PrioVerifier_new(server)

        verify1_internal = libprio.PrioPacketVerify1_new()
        verify1_external = libprio.PrioPacketVerify1_new()
        verify2 = libprio.PrioPacketVerify2_new()

        try:
            # bytes() normalizes bytes-like inputs such as bytearray.
            libprio.PrioVerifier_set_data(verifier, bytes(share))
            for packet1, raw in ((verify1_internal, internal),
                                 (verify1_external, external)):
                libprio.PrioPacketVerify1_read(packet1, bytes(raw), config)
            libprio.PrioPacketVerify2_set_data(
                verify2, verifier, verify1_internal, verify1_external)
            return libprio.PrioPacketVerify2_write(verify2)
        except (RuntimeError, ValueError, TypeError):
            # Best effort: a malformed row yields None instead of failing.
            return None
예제 #2
0
def test_client_agg(n_clients):
    """Check that n_clients identical submissions sum to n_clients * input.

    Servers A and B share one config and one PRG seed. Every client encodes
    the same data vector, and the final combined total must equal each field
    multiplied by the number of clients.
    """
    seed = prio.PrioPRGSeed_randomize()

    secret_a, public_a = prio.Keypair_new()
    secret_b, public_b = prio.Keypair_new()
    config = prio.PrioConfig_new(133, public_a, public_b, b"test_batch")
    server_a = prio.PrioServer_new(config, prio.PRIO_SERVER_A, secret_a, seed)
    server_b = prio.PrioServer_new(config, prio.PRIO_SERVER_B, secret_b, seed)
    verifier_a = prio.PrioVerifier_new(server_a)
    verifier_b = prio.PrioVerifier_new(server_b)
    total_a = prio.PrioTotalShare_new()
    total_b = prio.PrioTotalShare_new()

    field_count = prio.PrioConfig_numDataFields(config)
    # A field is 1 when its index is 1 mod 3 or 1 mod 5, otherwise 0.
    data_items = bytes(
        [(idx % 3 == 1) or (idx % 5 == 1) for idx in range(field_count)])

    for _ in range(n_clients):
        share_a, share_b = prio.PrioClient_encode(config, data_items)

        prio.PrioVerifier_set_data(verifier_a, share_a)
        prio.PrioVerifier_set_data(verifier_b, share_b)

        prio.PrioServer_aggregate(server_a, verifier_a)
        prio.PrioServer_aggregate(server_b, verifier_b)

    prio.PrioTotalShare_set_data(total_a, server_a)
    prio.PrioTotalShare_set_data(total_b, server_b)

    final_bytes = prio.PrioTotalShare_final(config, total_a, total_b)
    output = array.array('L', final_bytes)

    # Iterating bytes yields ints, so each expected entry is 0 or n_clients.
    expected = [bit * n_clients for bit in data_items]
    assert list(output) == expected
예제 #3
0
def aggregate(
    batch_id: bytes,
    n_data: int,
    server_id: str,
    private_key_hex: bytes,
    shared_secret: bytes,
    public_key_hex_internal: bytes,
    public_key_hex_external: bytes,
    pdf: pd.DataFrame,
) -> pd.DataFrame:
    """Verify every share in a group and fold the accepted ones into a server.

    This differs from the row-wise UDFs: it is used as a grouped map
    (``DataFrame.applyInPandas``). The whole group arrives as one pandas
    DataFrame, and a single summary row is produced for it.

    ``pdf`` must provide ``shares``, ``internal`` and ``external`` columns.
    Each row holds a client share for this server plus the serialized
    PrioPacketVerify2 packets read alongside it. Rows for which any libprio
    call raises RuntimeError, ValueError or TypeError are counted as errors
    and skipped; every other row is aggregated.

    Returns a one-row DataFrame with schema:
    payload binary, error int, total int
    ``payload`` is the serialized server state, ``error`` the number of
    rejected rows and ``total`` the number of rows seen.
    """
    libprio.Prio_init()
    # Keep Prio_init/Prio_clear paired even if key import, setup or
    # serialization raises; previously an exception skipped Prio_clear.
    try:
        private_key, public_key_internal, public_key_external = import_keys(
            private_key_hex, public_key_hex_internal, public_key_hex_external)

        config = libprio.PrioConfig_new(n_data, public_key_internal,
                                        public_key_external, batch_id)
        server = libprio.PrioServer_new(config, match_server(server_id),
                                        private_key, shared_secret)
        verifier = libprio.PrioVerifier_new(server)
        # Packet objects are reused for every row; each read refills them.
        packet2_internal = libprio.PrioPacketVerify2_new()
        packet2_external = libprio.PrioPacketVerify2_new()

        total, error = 0, 0
        error_counter = Counter()
        for share, internal, external in zip(pdf.shares, pdf.internal,
                                             pdf.external):
            total += 1
            try:
                libprio.PrioVerifier_set_data(verifier, share)
                libprio.PrioPacketVerify2_read(packet2_internal, internal,
                                               config)
                libprio.PrioPacketVerify2_read(packet2_external, external,
                                               config)
                # The return value is not checked; a row is only rejected
                # when one of these calls raises.
                libprio.PrioVerifier_isValid(verifier, packet2_internal,
                                             packet2_external)
                libprio.PrioServer_aggregate(server, verifier)
            except (RuntimeError, ValueError, TypeError) as e:
                error += 1
                error_counter.update([f"server {server_id}: {e}"])
        logger.warning(error_counter)
        payload = libprio.PrioServer_write(server)
    finally:
        libprio.Prio_clear()
    return pd.DataFrame([dict(payload=payload, error=error, total=total)])
예제 #4
0
 def _process(share):
     """Return this server's serialized first-round verification packet.

     Loads ``share`` into a fresh verifier, derives a PrioPacketVerify1 from
     it and serializes the packet. Returns ``None`` when libprio raises
     RuntimeError, ValueError or TypeError for this share.
     """
     # Key material, n_data, batch_id, server_id and shared_secret come from
     # the enclosing scope.
     keys = import_keys(private_key_hex, public_key_hex_internal,
                        public_key_hex_external)
     private_key, pk_internal, pk_external = keys

     config = libprio.PrioConfig_new(n_data, pk_internal, pk_external,
                                     batch_id)
     server = libprio.PrioServer_new(config, match_server(server_id),
                                     private_key, shared_secret)
     verifier = libprio.PrioVerifier_new(server)
     verify1 = libprio.PrioPacketVerify1_new()

     try:
         # share is actually a bytearray, so convert it into bytes
         libprio.PrioVerifier_set_data(verifier, bytes(share))
         libprio.PrioPacketVerify1_set_data(verify1, verifier)
         return libprio.PrioPacketVerify1_write(verify1)
     except (RuntimeError, ValueError, TypeError):
         # Best effort: an unreadable share maps to None.
         return None
예제 #5
0
    def _process(internal, external):
        """Combine two serialized total shares into the final per-field sums.

        ``internal`` and ``external`` are parsed as PrioTotalShare payloads.
        When this process is server B, they are passed to
        ``PrioTotalShare_final`` in swapped order, because the call takes the
        server-A share first. Returns the result as a list of ints, decoded
        with array typecode "L" (C unsigned long).
        """
        # Only the public keys are needed to rebuild the config.
        _, pk_internal, pk_external = import_keys(
            private_key_hex, public_key_hex_internal, public_key_hex_external)
        config = libprio.PrioConfig_new(n_data, pk_internal, pk_external,
                                        batch_id)

        share_int = libprio.PrioTotalShare_new()
        share_ext = libprio.PrioTotalShare_new()
        libprio.PrioTotalShare_read(share_int, internal, config)
        libprio.PrioTotalShare_read(share_ext, external, config)

        # ordering matters
        is_server_b = match_server(server_id) == libprio.PRIO_SERVER_B
        first, second = (share_ext, share_int) if is_server_b else (
            share_int, share_ext)

        raw_totals = libprio.PrioTotalShare_final(config, first, second)
        return array.array("L", raw_totals).tolist()
예제 #6
0
def encode_single(
    batch_id: str,
    n_data: int,
    public_key_hex_internal: bytearray,
    public_key_hex_external: bytearray,
    payload: list,
):
    """Encode one client payload into a pair of server shares.

    Every element of ``payload`` is converted with ``int`` and packed into
    bytes before calling ``PrioClient_encode``. ``batch_id`` is UTF-8 encoded
    for the config.

    Returns ``{"a": share_a, "b": share_b}``. Both values are ``None`` if
    conversion or encoding raises RuntimeError, ValueError or TypeError.
    """
    libprio.Prio_init()
    # NOTE(review): this function never calls libprio.Prio_clear -- confirm
    # that leaving the library initialized is intentional for per-row use.
    key_internal = bytes(public_key_hex_internal)
    key_external = bytes(public_key_hex_external)
    pk_internal, pk_external = import_public_keys(key_internal, key_external)
    config = libprio.PrioConfig_new(n_data, pk_internal, pk_external,
                                    batch_id.encode())
    try:
        encoded = bytes([int(bit) for bit in payload])
        a, b = libprio.PrioClient_encode(config, encoded)
    except (RuntimeError, ValueError, TypeError):
        a, b = None, None
    return {"a": a, "b": b}
예제 #7
0
def encode(
    batch_id: bytes,
    n_data: int,
    public_key_hex_internal: bytes,
    public_key_hex_external: bytes,
    payload: pd.Series,
) -> pd.DataFrame:
    """Encode each payload row into a pair of shares for the two servers.

    Every row of ``payload`` is iterated, each element is converted with
    ``int``, and the result is packed into bytes for ``PrioClient_encode``.
    A row whose conversion or encoding raises RuntimeError, ValueError or
    TypeError yields ``a=None, b=None`` instead of failing the batch.

    Returns a DataFrame with columns ``a`` and ``b``, one row per payload
    row. The columns are present even when ``payload`` is empty.
    """
    libprio.Prio_init()
    # Keep Prio_init/Prio_clear paired even when key import or config
    # creation raises; previously an exception skipped Prio_clear.
    try:
        public_key_internal, public_key_external = import_public_keys(
            public_key_hex_internal, public_key_hex_external)
        config = libprio.PrioConfig_new(n_data, public_key_internal,
                                        public_key_external, batch_id)
        results = []
        for data in payload:
            # presumably each element of `data` is a 0/1 bit; every element
            # must fit in one byte or bytes() raises ValueError.
            try:
                a, b = libprio.PrioClient_encode(config, bytes(map(int, data)))
            except (RuntimeError, ValueError, TypeError):
                a, b = None, None
            results.append(dict(a=a, b=b))
    finally:
        libprio.Prio_clear()
    # Explicit columns so an empty batch still yields the (a, b) shape.
    return pd.DataFrame(results, columns=["a", "b"])
예제 #8
0
def total_share(
    batch_id: bytes,
    n_data: int,
    server_id: str,
    private_key_hex: bytes,
    shared_secret: bytes,
    public_key_hex_internal: bytes,
    public_key_hex_external: bytes,
    pdf: pd.DataFrame,
) -> pd.DataFrame:
    """Merge serialized server aggregates and emit this server's total share.

    ``pdf`` must provide ``payload`` (serialized server state, as produced by
    ``PrioServer_write``), ``error`` and ``total`` columns. Every payload is
    merged into one server, and the resulting PrioTotalShare is serialized.

    Returns a one-row DataFrame with schema: payload binary, error int,
    total int. Here ``error`` and ``total`` are the sums of the input columns.
    """
    libprio.Prio_init()
    # Keep Prio_init/Prio_clear paired even if reading or merging a payload
    # raises; previously an exception skipped Prio_clear.
    try:
        private_key, public_key_internal, public_key_external = import_keys(
            private_key_hex, public_key_hex_internal, public_key_hex_external)
        config = libprio.PrioConfig_new(n_data, public_key_internal,
                                        public_key_external, batch_id)
        server = libprio.PrioServer_new(config, match_server(server_id),
                                        private_key, shared_secret)
        # Scratch server that each serialized aggregate is read into before
        # being merged into `server`. This assumes PrioServer_read replaces,
        # rather than adds to, server_i's state -- TODO confirm.
        server_i = libprio.PrioServer_new(config, match_server(server_id),
                                          private_key, shared_secret)
        # NOTE: this breaks expectations from other udfs, which expects shares,
        # internal, external etc.
        for aggregates in pdf["payload"]:
            libprio.PrioServer_read(server_i, aggregates, config)
            libprio.PrioServer_merge(server, server_i)

        # Renamed from `total_share` so the local no longer shadows this
        # function's name.
        share = libprio.PrioTotalShare_new()
        libprio.PrioTotalShare_set_data(share, server)
        payload = libprio.PrioTotalShare_write(share)
    finally:
        libprio.Prio_clear()
    result = [
        dict(
            payload=payload,
            error=pdf.error.sum(),
            total=pdf.total.sum(),
        )
    ]
    return pd.DataFrame(result)
예제 #9
0
 def __init__(self, n_fields, server_a, server_b, batch_id):
     """Create the native PrioConfig and store it on ``self.instance``.

     ``server_a`` and ``server_b`` must expose an ``instance`` attribute.
     These are presumably the wrapped public keys of servers A and B; confirm
     against callers.
     """
     native_a = server_a.instance
     native_b = server_b.instance
     self.instance = libprio.PrioConfig_new(n_fields, native_a, native_b,
                                            batch_id)