Esempio n. 1
0
def test_select_shares_exception_ring(get_clients) -> None:
    """Expect ``Falcon.select_shares`` to raise ``ValueError`` for 2**32-ring operands.

    The session is created with ``ring_size=2**32``, so the tensor used for all
    three arguments (including the selector) belongs to that ring.
    """
    # NOTE(review): presumably select_shares requires the selector to be a
    # ring-2 (bit) share, which is why this call is rejected — confirm.
    session = Session(
        parties=get_clients(3), protocol=Falcon(), ring_size=2**32
    )
    SessionManager.setup_mpc(session)

    operand = MPCTensor(secret=1, session=session)
    with pytest.raises(ValueError):
        Falcon.select_shares(operand, operand, operand)
Esempio n. 2
0
def test_select_shares_exception_shape(get_clients) -> None:
    """Expect ``Falcon.select_shares`` to raise ``ValueError`` for a shapeless tensor.

    The first share of the tensor is copied locally, relabelled with
    ``ring_size = 2`` and sent back to the first party. The tensor's ``shape``
    is then cleared before ``select_shares`` is called.
    """
    clients = get_clients(3)
    session = Session(parties=clients, protocol=Falcon())
    SessionManager.setup_mpc(session)

    operand = MPCTensor(secret=1, session=session)

    # NOTE(review): the ring_size tweak presumably lets a ring check pass so
    # that the missing shape is what triggers the error — confirm.
    local_share = operand.share_ptrs[0].get_copy()
    local_share.ring_size = 2
    operand.share_ptrs[0] = local_share.send(clients[0])
    operand.shape = None

    with pytest.raises(ValueError):
        Falcon.select_shares(operand, operand, operand)
Esempio n. 3
0
def test_select_shares(get_clients, security) -> None:
    """Check that ``Falcon.select_shares`` picks elementwise between x and y.

    With selector bits ``[[1, 0], [0, 1]]``, the expected output takes y's
    values where the bit is 1 and x's values where the bit is 0.
    """
    clients = get_clients(3)
    session = Session(parties=clients, protocol=Falcon(security))
    SessionManager.setup_mpc(session)

    bit = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool)
    # Three identical ring-2 shares; presumably they reconstruct to `bit`
    # itself (1 + 1 + 1 == 1 mod 2) — confirm against the sharing scheme.
    bit_ptrs = ReplicatedSharedTensor.distribute_shares(
        [bit] * 3, session, ring_size=2
    )
    selector = MPCTensor(shares=bit_ptrs, session=session, shape=bit.shape)

    x = MPCTensor(secret=torch.tensor([[1, 2], [3, 4]]), session=session)
    y = MPCTensor(secret=torch.tensor([[5, 6], [7, 8]]), session=session)

    selected = Falcon.select_shares(x, y, selector)

    expected = torch.tensor([[5.0, 2.0], [3.0, 8.0]])
    assert (expected == selected.reconstruct()).all()