def test_select_shares_exception_ring(get_clients) -> None:
    """Selecting with a selector tensor shared over a 2**32 ring raises ValueError."""
    clients = get_clients(3)
    session = Session(parties=clients, protocol=Falcon(), ring_size=2**32)
    SessionManager.setup_mpc(session)

    # One 2**32-ring tensor is reused as x, y and the selector bit.
    tensor = MPCTensor(secret=1, session=session)
    with pytest.raises(ValueError):
        Falcon.select_shares(tensor, tensor, tensor)
def test_select_shares_exception_shape(get_clients) -> None:
    """Selecting with a selector tensor whose shape is None raises ValueError."""
    clients = get_clients(3)
    session = Session(parties=clients, protocol=Falcon())
    SessionManager.setup_mpc(session)
    selector = MPCTensor(secret=1, session=session)

    # Re-send the first share with ring_size forced to 2. This is presumably
    # meant to satisfy a ring-size check, so that the missing shape is what
    # triggers the error -- confirm against Falcon.select_shares.
    first_share = selector.share_ptrs[0].get_copy()
    first_share.ring_size = 2
    selector.share_ptrs[0] = first_share.send(clients[0])
    selector.shape = None

    with pytest.raises(ValueError):
        Falcon.select_shares(selector, selector, selector)
def test_select_shares(get_clients, security) -> None:
    """select_shares takes y where the selector bit is 1 and x where it is 0."""
    clients = get_clients(3)
    session = Session(parties=clients, protocol=Falcon(security))
    SessionManager.setup_mpc(session)

    # Three identical boolean shares are distributed over ring_size=2 to form the selector.
    bit_share = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool)
    bit_ptrs = ReplicatedSharedTensor.distribute_shares(
        [bit_share] * 3, session, ring_size=2
    )
    selector = MPCTensor(shares=bit_ptrs, session=session, shape=bit_share.shape)

    x = MPCTensor(secret=torch.tensor([[1, 2], [3, 4]]), session=session)
    y = MPCTensor(secret=torch.tensor([[5, 6], [7, 8]]), session=session)

    result = Falcon.select_shares(x, y, selector)

    # Entries where bit_share is 1 come from y; the remaining entries come from x.
    expected = torch.tensor([[5.0, 2.0], [3.0, 8.0]])
    assert (expected == result.reconstruct()).all()