예제 #1
0
def test_serialize_total_shares(config, client, serverA, serverB):
    """Round-trip server A's total shares through pickle and check the result.

    One client submission is encoded into one share per server. Both
    verifiers run the two-packet exchange, and each server aggregates its
    verifier. Server A's total shares are then pickled and unpickled before
    being combined with server B's. The combined output must equal the
    submitted payload.
    """
    field_count = config.num_data_fields()
    # One byte per field: 1 where idx % 3 == 1 or idx % 5 == 1, otherwise 0.
    payload = bytes([(idx % 3 == 1) or (idx % 5 == 1) for idx in range(field_count)])

    share_a, share_b = client.encode(payload)

    verifier_a = serverA.create_verifier(share_a)
    verifier_b = serverB.create_verifier(share_b)

    # Round one, then round two; each round-two packet needs both round-one packets.
    packet1_a, packet1_b = verifier_a.create_verify1(), verifier_b.create_verify1()
    packet2_a = verifier_a.create_verify2(packet1_a, packet1_b)
    packet2_b = verifier_b.create_verify2(packet1_a, packet1_b)

    for verifier in (verifier_a, verifier_b):
        assert verifier.is_valid(packet2_a, packet2_b)

    serverA.aggregate(verifier_a)
    serverB.aggregate(verifier_b)

    # Only A's totals go through pickle; B's are used directly.
    restored_a = pickle.loads(pickle.dumps(serverA.total_shares()))
    totals_b = serverB.total_shares()

    combined = prio.total_share_final(config, restored_a, totals_b)
    assert list(combined) == list(payload)
예제 #2
0
def prio_aggregate(init_func, group):
    """Verify and aggregate one group's Prio submissions into a summary Series.

    Args:
        init_func: Called with the group's batch id. Must return
            ``(config, server_a, server_b)``.
        group: A pandas object with a ``name`` attribute (presumably a group
            passed by ``groupby(...).apply``). It has the share columns
            ``prio_data_a`` / ``prio_data_b`` and the plaintext control
            columns ``browser_is_user_default``, ``newtab_page_enabled``
            and ``pdf_viewer_used``.

    Returns:
        pd.Series with keys ``prio_control``, ``prio_observed``,
        ``prio_diff``, ``prio_null_data``, ``prio_invalid`` and ``count``.
    """
    key = group.name
    # A non-string group key is indexed; its first element is the batch id.
    batch_id = key if isinstance(key, str) else key[0]
    config, server_a, server_b = init_func(batch_id)

    def to_bytes(arr):
        # Cast every element to uint8 and take the raw byte buffer.
        return bytes(arr.astype('uint8'))

    null_count = 0
    invalid_count = 0

    shares_a = group.prio_data_a.apply(to_bytes)
    shares_b = group.prio_data_b.apply(to_bytes)
    for blob_a, blob_b in zip(shares_a, shares_b):
        # An empty share on either side means there is nothing to verify.
        if not (blob_a and blob_b):
            null_count += 1
            continue

        verifier_a = server_a.create_verifier(blob_a)
        verifier_b = server_b.create_verifier(blob_b)
        packet1_a = verifier_a.create_verify1()
        packet1_b = verifier_b.create_verify1()
        packet2_a = verifier_a.create_verify2(packet1_a, packet1_b)
        packet2_b = verifier_b.create_verify2(packet1_a, packet1_b)

        # Aggregate only when both servers accept the submission.
        if verifier_a.is_valid(packet2_a, packet2_b) and verifier_b.is_valid(packet2_a, packet2_b):
            server_a.aggregate(verifier_a)
            server_b.aggregate(verifier_b)
        else:
            invalid_count += 1

    observed = np.array(prio.total_share_final(
        config, server_a.total_shares(), server_b.total_shares()))

    # Plaintext totals to compare against the Prio result.
    n_default = group.browser_is_user_default.sum()
    n_pdf = group.pdf_viewer_used.sum()
    n_newtab = group.newtab_page_enabled.sum()
    # NOTE(review): assumes the Prio fields are encoded in this same order
    # (default, newtab, pdf) - confirm against the client-side encoder.
    control = np.array([n_default, n_newtab, n_pdf])

    summary = {
        "prio_control": control,
        "prio_observed": observed,
        "prio_diff": (control - observed).astype('int'),
        "prio_null_data": null_count,
        "prio_invalid": invalid_count,
        "count": len(group.index),
    }
    return pd.Series(summary, index=summary.keys())
예제 #3
0
def test_client_agg(n_clients):
    """Aggregate ``n_clients`` identical submissions and check the sum.

    Builds a fresh two-server deployment: both servers share one PRG seed,
    each server has its own keypair, and the config has 133 fields. Every
    client submits the same payload. Each output field must equal
    ``n_clients`` times the submitted value.
    """
    prg_seed = prio.PRGSeed()

    secret_a, public_a = prio.create_keypair()
    secret_b, public_b = prio.create_keypair()

    # The client and both servers are all built from this one config.
    config = prio.Config(133, public_a, public_b, b"test_batch")

    server_a = prio.Server(config, prio.PRIO_SERVER_A, secret_a, prg_seed)
    server_b = prio.Server(config, prio.PRIO_SERVER_B, secret_b, prg_seed)

    client = prio.Client(config)

    field_count = config.num_data_fields()
    payload = bytes([(k % 3 == 1) or (k % 5 == 1) for k in range(field_count)])

    for _ in range(n_clients):
        share_a, share_b = client.encode(payload)

        verifier_a = server_a.create_verifier(share_a)
        verifier_b = server_b.create_verifier(share_b)

        # Round one: each side emits a packet for its peer.
        packet1_a, packet1_b = verifier_a.create_verify1(), verifier_b.create_verify1()
        # Round two: each side combines both round-one packets.
        packet2_a = verifier_a.create_verify2(packet1_a, packet1_b)
        packet2_b = verifier_b.create_verify2(packet1_a, packet1_b)

        for verifier in (verifier_a, verifier_b):
            assert verifier.is_valid(packet2_a, packet2_b)

        server_a.aggregate(verifier_a)
        server_b.aggregate(verifier_b)

    combined = prio.total_share_final(
        config, server_a.total_shares(), server_b.total_shares())

    assert list(combined) == [value * n_clients for value in payload]
예제 #4
0
async def main():
    """Run the asyncio Prio pipeline end to end and check the aggregate.

    Sets up two Prio servers that share a PRG seed. ``client_produce``
    enqueues ``n_clients`` submissions onto the per-server queues. One
    ``server_consume`` task is started per server, the function waits for
    both queues to drain, and then asserts that the combined total equals
    ``n_clients`` times the payload.

    Raises:
        AssertionError: The combined total is not the expected sum.
        Exception: Whatever a consumer raised, if it failed before the
            queues were drained.
    """
    n_clients = 10
    n_data = 133
    server_secret = prio.PRGSeed()
    skA, pkA = prio.create_keypair()
    skB, pkB = prio.create_keypair()

    cfg = prio.Config(n_data, pkA, pkB, b"test_batch")
    sA = prio.Server(cfg, prio.PRIO_SERVER_A, skA, server_secret)
    sB = prio.Server(cfg, prio.PRIO_SERVER_B, skB, server_secret)

    data_items = bytes([(i % 3 == 1) or (i % 5 == 1) for i in range(n_data)])

    logger.info("Starting asyncio prio pipeline.")
    client = prio.Client(cfg)
    queue_a = asyncio.Queue()
    queue_b = asyncio.Queue()

    await client_produce(client, data_items, queue_a, queue_b, n_clients)

    consumers = asyncio.ensure_future(
        asyncio.gather(
            server_consume(sA, queue_a, queue_b),
            server_consume(sB, queue_b, queue_a),
        ))
    drained = asyncio.ensure_future(
        asyncio.gather(queue_a.join(), queue_b.join()))

    try:
        # Wait for the queues to drain, but wake up early if the consumers
        # finish first. Items a dead consumer never marks with task_done()
        # would make join() alone block forever.
        done, _ = await asyncio.wait(
            {drained, consumers}, return_when=asyncio.FIRST_COMPLETED)
        if consumers in done:
            consumers.result()  # re-raises a consumer's exception, if any
        await drained

        t_a = sA.total_shares()
        t_b = sB.total_shares()

        output = prio.total_share_final(cfg, t_a, t_b)

        expected = [item * n_clients for item in data_items]
        assert list(output) == expected
    finally:
        # Stop the consumers on every exit path (including a failed assert)
        # and await both futures so no pending task or unretrieved
        # exception is left behind.
        drained.cancel()
        consumers.cancel()
        await asyncio.gather(drained, consumers, return_exceptions=True)

    logger.info("Done!")
예제 #5
0
# Create one verifier per server for this client's submission.
vA = sA.create_verifier(for_server_a)
vB = sB.create_verifier(for_server_b)

# Produce a packet1 and send to the other party
p1A = vA.create_verify1()
p1B = vB.create_verify1()

# Produce packet2 and send to the other party
p2A = vA.create_verify2(p1A, p1B)
p2B = vB.create_verify2(p1A, p1B)

# Check validity of the request; report failures on stderr and exit non-zero.
if not vA.is_valid(p2A, p2B):
    print("data for server A is not valid!", file=sys.stderr)
    sys.exit(1)
if not vB.is_valid(p2A, p2B):
    print("data for server B is not valid!", file=sys.stderr)
    sys.exit(1)

sA.aggregate(vA)
sB.aggregate(vB)

# Each server exports its aggregated shares; combining both yields the total.
tA = sA.total_shares()
tB = sB.total_shares()

output = prio.total_share_final(cfg, tA, tB)

# The check expects the aggregate to equal the submitted data, i.e. exactly
# one submission was aggregated. Use an explicit check rather than `assert`,
# which `python -O` strips.
if list(output) != list(data_items):
    print("aggregated output does not match the submitted data!", file=sys.stderr)
    sys.exit(1)