def test_mse_loss(get_clients, reduction) -> None:
    """Compare the MPC ``mse_loss`` with PyTorch's reference for the given reduction."""
    session = Session(parties=get_clients(4))
    SessionManager.setup_mpc(session)

    target = torch.Tensor([0.23, 0.32, 0.2, 0.3])
    target_mpc = MPCTensor(secret=target, session=session)

    prediction = torch.Tensor([0.1, 0.3, 0.4, 0.2])
    prediction_mpc = MPCTensor(secret=prediction, session=session)

    loss_mpc = mse_loss(target_mpc, prediction_mpc, reduction)
    loss_reference = torch.nn.functional.mse_loss(target, prediction, reduction=reduction)

    assert np.allclose(loss_mpc.reconstruct(), loss_reference, atol=1e-4)
def test_session_manager_init():
    """Check UUID assignment on default and explicitly-seeded construction."""
    # Default construction: a UUID is generated automatically.
    default_session = SessionManager()
    assert isinstance(default_session.uuid, UUID)

    # Explicit construction: the provided UUID is stored unchanged.
    custom_uuid = uuid4()
    custom_session = Session(uuid=custom_uuid)
    assert custom_session.uuid == custom_uuid
def test_prrs_rst_ring_size(get_clients) -> None:
    """Random RSTs from ``prrs_generate_random_share`` honour the requested ring size."""
    parties = get_clients(3)
    session = Session(protocol=Falcon(), parties=parties)
    SessionManager.setup_mpc(session)

    for ring_size, expected_dtype in RING_SIZE_TO_TYPE.items():
        rst_ptr = session.session_ptrs[0].prrs_generate_random_share(
            shape=(1, 2), ring_size=str(ring_size)
        )
        rst = rst_ptr.get_copy()

        assert rst.ring_size == ring_size
        assert rst.shares[0].dtype == expected_dtype
        assert rst.shares[1].dtype == expected_dtype

        # In the prime ring every share element must lie in [0, PRIME_NUMBER - 1].
        if ring_size == PRIME_NUMBER:
            joined = torch.cat(rst.shares)
            assert torch.max(joined) <= PRIME_NUMBER - 1
            assert torch.min(joined) >= 0
def test_max_dim(dim, keepdim, get_clients) -> None:
    """MPCTensor.max along ``dim`` matches torch.max for both values and indices."""
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    secret = torch.Tensor([[[1, 2], [3, -1], [4, 5]], [[2, 5], [5, 1], [6, 42]]])
    x = MPCTensor(secret=secret, session=session)

    max_val, max_idx_val = x.max(dim=dim, keepdim=keepdim)

    # Check the *outputs* of max(); the input ``x`` is trivially an MPCTensor,
    # so checking it, as before, asserted nothing.
    assert isinstance(max_val, MPCTensor), "Expected max values to be MPCTensor"
    assert isinstance(max_idx_val, MPCTensor), "Expected argmax to be MPCTensor"

    res_idx = max_idx_val.reconstruct()
    res_max = max_val.reconstruct()

    expected_max, expected_indices = secret.max(dim=dim, keepdim=keepdim)

    assert (
        res_idx == expected_indices
    ).all(), f"Expected indices for maximum to be {expected_indices}"
    assert (res_max == expected_max).all(), f"Expected maximum to be {expected_max}"
def test_ops_mpc_public(get_clients, nr_clients, op_str) -> None:
    """A binary operator between an MPCTensor and a public tensor matches plaintext torch."""
    session = Session(parties=get_clients(nr_clients))
    SessionManager.setup_mpc(session)

    x_secret = torch.Tensor([[0.125, -1.25], [-4.25, 4]])
    # truediv is exercised with an integer (long) public divisor; other ops use floats.
    y_secret = (
        torch.Tensor([[2, 3], [4, 5]]).long()
        if op_str == "truediv"
        else torch.Tensor([[4.5, -2.5], [5, 2.25]])
    )

    x = MPCTensor(secret=x_secret, session=session)
    operation = getattr(operator, op_str)

    expected = operation(x_secret, y_secret)
    actual = operation(x, y_secret).reconstruct()

    assert np.allclose(actual, expected, atol=10e-4)
def test_ops_public_mul_integer_parties(get_clients, parties, security):
    """Scaling every share pointer by a public integer yields shares of ``secret * value``."""
    config = Config(encoder_base=1, encoder_precision=0)
    parties = get_clients(parties)
    session = Session(protocol=Falcon(security), parties=parties, config=config)
    SessionManager.setup_mpc(session)

    secret = torch.tensor([[-100, 20, 30], [-90, 1000, 1], [1032, -323, 15]])
    value = 8
    tensor = MPCTensor(secret=secret, session=session)

    # Multiply each share pointer by the public value and rebuild an MPCTensor.
    scaled_shares = [operator.mul(share_ptr, value) for share_ptr in tensor.share_ptrs]
    result = MPCTensor(shares=scaled_shares, session=session)

    assert (result.reconstruct() == (secret * value)).all()
def test_conv_mpc_mpc(get_clients, nr_clients, bias, stride, padding, op_str) -> None:
    """Private-input / private-weight convolution matches torch.nn.functional.<op_str>."""
    clients = get_clients(nr_clients)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    input_secret = torch.ones(1, 1, 4, 4)
    weight_secret = torch.ones(1, 1, 2, 2)

    # Named ``input_mpc`` (not ``input``) so the builtin is not shadowed.
    input_mpc = MPCTensor(secret=input_secret, session=session)
    weight_mpc = MPCTensor(secret=weight_secret, session=session)

    kwargs = {"bias": bias, "stride": stride, "padding": padding}

    mpc_op = getattr(MPCTensor, op_str)
    result = mpc_op(input_mpc, weight_mpc, **kwargs).reconstruct()

    torch_op = getattr(torch.nn.functional, op_str)
    expected_result = torch_op(input_secret, weight_secret, **kwargs)

    assert np.allclose(result, expected_result, rtol=10e-4)
def test_mul_private_matrix(get_clients, security, base, precision):
    """Element-wise private*private product of two 3x3 fixed-point matrices under Falcon."""
    parties = get_clients(3)
    protocol = Falcon(security)
    config = Config(encoder_base=base, encoder_precision=precision)
    session = Session(protocol=protocol, parties=parties, config=config)
    SessionManager.setup_mpc(session)

    lhs = torch.tensor(
        [[-100.25, 20.3, 30.12], [-50.1, 100.217, 1.2], [1032.15, -323.56, 15.15]]
    )
    rhs = torch.tensor([[-1, 0.28, 3], [-9, 10.18, 1], [32, -23, 5]])

    lhs_mpc = MPCTensor(secret=lhs, session=session)
    rhs_mpc = MPCTensor(secret=rhs, session=session)

    product_mpc = lhs_mpc * rhs_mpc
    expected = lhs * rhs

    assert np.allclose(product_mpc.reconstruct(), expected, atol=1e-3)
def test_eq() -> None:
    """Check Session equality before and after MPC setup."""
    session = Session()
    fresh_session = Session()
    same_object = session

    # An object of a different type is never equal to a Session.
    assert session != 1

    # A Session compares equal to itself (same object).
    assert session == same_object

    # Two distinct sessions compare equal while neither has been set up...
    assert session == fresh_session

    # ...but no longer once setup_mpc has run on one of them.
    SessionManager.setup_mpc(session)
    assert session != fresh_session
def test_reconstruct(get_clients) -> None:
    """Secret-share and reconstruct a float tensor; also smoke-test generate_shares."""
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    raw_value = 3
    share = ShareTensor(data=raw_value, config=Config(encoder_precision=0))

    # Smoke checks: the results are discarded, so these only verify that
    # generate_shares does not raise for a ShareTensor secret or a plain int secret.
    MPCTensor.generate_shares(secret=share, nr_parties=2, tensor_type=torch.long)
    MPCTensor.generate_shares(
        secret=raw_value, nr_parties=2, config=Config(), tensor_type=torch.long
    )

    x_secret = torch.Tensor([1, -2, 3.0907, -4.870])
    x_mpc = MPCTensor(secret=x_secret, session=session)
    reconstructed = x_mpc.reconstruct()

    assert np.allclose(x_secret, reconstructed)
def test_bit_decomposition_ttp(get_clients, security_type) -> None:
    """Bits from ABY3.bit_decomposition_ttp recombine into the encoded secret."""
    parties = get_clients(3)
    session = Session(parties=parties, protocol=Falcon(security_type=security_type))
    SessionManager.setup_mpc(session)

    secret = torch.tensor([[-1, 12], [-32, 45], [98, -5624]])
    x = MPCTensor(secret=secret, session=session)
    bit_shares = ABY3.bit_decomposition_ttp(x, session)

    ring_size = x.session.ring_size
    tensor_type = x.session.tensor_type
    nr_bits = get_nr_bits(ring_size)

    # OR each reconstructed bit back into position ``bit_idx``.
    recombined = torch.zeros(size=x.shape, dtype=tensor_type)
    for bit_idx in range(nr_bits):
        bit = bit_shares[bit_idx].reconstruct(decode=False).type(tensor_type)
        recombined |= bit << bit_idx

    # Expected values equal secret * 2**16, i.e. the raw (undecoded) integers.
    expected = torch.tensor(
        [[-65536, 786432], [-2097152, 2949120], [6422528, -368574464]]
    )
    assert (recombined == expected).all()
def test_truncation_algorithm1(get_clients, base, precision) -> None:
    """ABY3.truncate matches floor-dividing the raw encoding by the encoder scale."""
    parties = get_clients(3)
    falcon = Falcon("semi-honest")
    config = Config(encoder_base=base, encoder_precision=precision)
    session = Session(parties=parties, protocol=falcon, config=config)
    SessionManager.setup_mpc(session)

    plain = torch.tensor([[1.24, 4.51, 6.87], [7.87, 1301, 541]])
    x_mpc = MPCTensor(secret=plain, session=session)

    truncated = ABY3.truncate(x_mpc, session, session.ring_size, session.config)

    encoder = FixedPointEncoder(
        base=session.config.encoder_base,
        precision=session.config.encoder_precision,
    )
    # Reference: integer-divide the raw encoded value by the scale, then decode.
    raw = x_mpc.reconstruct(decode=False)
    expected = encoder.decode(raw // encoder.scale)

    assert np.allclose(truncated.reconstruct(), expected, atol=1e-3)
def test_send_get(get_clients, precision=12, base=4) -> None:
    """A ReplicatedSharedTensor survives a send/get round trip unchanged."""
    # NOTE(review): ``precision`` and ``base`` are not used by this test.
    client = get_clients(1)[0]
    session = Session(protocol=Falcon("semi-honest"), parties=[client])
    SessionManager.setup_mpc(session)

    shares = [
        torch.Tensor([1.4, 2.34, 3.43]),
        torch.Tensor([1, 2, 3]),
        torch.Tensor([1.4, 2.34, 3.43]),
    ]
    rst = ReplicatedSharedTensor(shares=shares, session_uuid=session.rank_to_uuid[0])

    remote_ptr = rst.send(client)
    received = remote_ptr.get()

    assert received == rst
def test_primitive_logging_beaver_matmul(get_clients) -> None:
    """Logging records the (p_kwargs, g_kwargs) pair of a beaver_matmul request."""
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    p_kwargs = {"a_shape": (2, 3), "b_shape": (3, 10)}
    g_kwargs = {"a_shape": (2, 3), "b_shape": (3, 10), "nr_parties": 2}

    CryptoPrimitiveProvider.start_logging()
    # Pass the per-party session pointers via ``sessions``, matching the other
    # generate_primitives call sites in this module (previously ``session=session``).
    CryptoPrimitiveProvider.generate_primitives(
        sessions=session.session_ptrs,
        op_str="beaver_matmul",
        p_kwargs=p_kwargs,
        g_kwargs=g_kwargs,
    )
    primitive_log = CryptoPrimitiveProvider.stop_logging()

    expected_log = {"beaver_matmul": [(p_kwargs, g_kwargs)]}
    assert expected_log == primitive_log
def test_backward_without_requires_grad(get_clients):
    """backward() on a graph whose leaves do not require grad populates no gradients."""
    session = Session(parties=get_clients(4))
    session.autograd_active = True
    SessionManager.setup_mpc(session)

    x_plain = torch.tensor([[0.125, -1.25], [-4.25, 4], [-3, 3]])
    y_plain = torch.tensor([[4.5, -2.5], [5, 2.25], [-3, 3]])

    x = MPCTensor(secret=x_plain, session=session)
    y = MPCTensor(secret=y_plain, session=session)

    difference = x - y
    total = difference.sum()
    total.backward()

    # Neither operand asked for gradients, so nothing is tracked or filled in.
    assert not difference.requires_grad
    assert difference.grad is None
    assert x.grad is None
    assert y.grad is None
def test_prrs_share_tensor() -> None:
    """Test prrs_generate_random_share method from Session for ShareTensor."""
    session = Session()  # default protocol: FSS
    SessionManager.setup_mpc(session)

    seed1 = secrets.randbits(32)
    seed2 = secrets.randbits(32)

    gen1 = get_new_generator(seed1)
    gen2 = get_new_generator(seed2)
    # Install generators with known seeds so the expected share can be replayed.
    session.przs_generators = [gen1, gen2]

    shape = (2, 1)
    share = session.prrs_generate_random_share(shape=shape)
    assert isinstance(share, ShareTensor)

    # The test expects the share to be drawn from the first generator only:
    # replay it from a fresh generator seeded with seed1.
    replay_gen = get_new_generator(seed1)
    expected = generate_random_element(
        generator=replay_gen, shape=shape, tensor_type=session.tensor_type
    )

    assert (share.tensor == expected).all()
def test_generate_primitive_from_dict_beaver_conv2d(get_clients) -> None:
    """Replaying a logged beaver_conv2d request stores correctly-shaped primitives."""
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    a_shape = (1, 1, 28, 28)
    b_shape = (5, 1, 5, 5)

    primitive_log = {
        "beaver_conv2d": [
            (
                {"a_shape": a_shape, "b_shape": b_shape},
                {
                    "session": session,
                    "a_shape": a_shape,
                    "b_shape": b_shape,
                    "nr_parties": 2,
                },
            )
        ]
    }
    CryptoPrimitiveProvider.generate_primitive_from_dict(
        primitive_log=primitive_log, session=session
    )

    key = f"beaver_conv2d_{a_shape}_{b_shape}"
    stores = [
        session.session_ptrs[0].crypto_store.store.get(),
        session.session_ptrs[1].crypto_store.store.get(),
    ]

    # Each party's first stored primitive holds operands of the requested shapes.
    for store in stores:
        first_primitive = store.get(key)[0]
        assert a_shape == tuple(first_primitive[0].shape)
        assert b_shape == tuple(first_primitive[1].shape)
def test_generate_primitive(get_clients: Callable, nr_parties: int, nr_instances: int) -> None:
    """Without p_kwargs, generate_primitives returns the primitives of each party."""
    session = Session(parties=get_clients(nr_parties))
    SessionManager.setup_mpc(session)

    g_kwargs = {"nr_parties": nr_parties, "nr_instances": nr_instances}
    per_party = CryptoPrimitiveProvider.generate_primitives(
        "test",
        sessions=session.session_ptrs,
        g_kwargs=g_kwargs,
        p_kwargs=None,
    )

    assert isinstance(per_party, list)
    assert len(per_party) == nr_parties

    # The "test" primitive for party ``rank`` is a tuple filled with ``rank``.
    for rank, party_primitives in enumerate(per_party):
        expected = (rank,) * PRIMITIVE_NR_ELEMS
        for primitive in party_primitives:
            assert primitive == expected
def test_generate_and_transfer_primitive(
    get_clients: Callable, nr_parties: int, nr_instances: int
) -> None:
    """Generated "test" primitives land in each party's remote crypto store."""
    session = Session(parties=get_clients(nr_parties))
    SessionManager.setup_mpc(session)

    CryptoPrimitiveProvider.generate_primitives(
        "test",
        n_instances=nr_instances,
        sessions=session.session_ptrs,
        g_kwargs={"nr_parties": nr_parties},
        p_kwargs={},
    )

    for rank in range(nr_parties):
        crypto_store_ptr = session.session_ptrs[rank].crypto_store
        stored = crypto_store_ptr.get_primitives_from_store("test").get()
        assert stored == [(rank,) * PRIMITIVE_NR_ELEMS]
def test_generate_primitive_from_dict_beaver_matmul(get_clients) -> None:
    """Replaying a logged beaver_matmul request stores correctly-shaped primitives."""
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    a_shape = (2, 3)
    b_shape = (3, 10)

    primitive_log = {
        "beaver_matmul": [
            (
                {"a_shape": a_shape, "b_shape": b_shape},
                {"a_shape": a_shape, "b_shape": b_shape, "nr_parties": 2},
            )
        ]
    }
    CryptoPrimitiveProvider.generate_primitive_from_dict(
        primitive_log=primitive_log, session=session
    )

    key = f"beaver_matmul_{a_shape}_{b_shape}"
    stores = [
        session.session_ptrs[0].crypto_store.store.get(),
        session.session_ptrs[1].crypto_store.store.get(),
    ]

    # Each party's first stored primitive holds operands of the requested shapes.
    for store in stores:
        first_primitive = store.get(key)[0]
        assert a_shape == tuple(first_primitive[0].shape)
        assert b_shape == tuple(first_primitive[1].shape)
def test_load_sympc() -> None:
    """After ``sy.load("sympc")``, an MPCTensor can be added to a public tensor."""
    alice_vm = sy.VirtualMachine()
    alice_client = alice_vm.get_root_client()
    bob_vm = sy.VirtualMachine()
    bob_client = bob_vm.get_root_client()

    # third party
    from sympc.session import Session
    from sympc.session import SessionManager
    from sympc.tensor import MPCTensor

    sy.load("sympc")

    session = Session(parties=[alice_client, bob_client])
    SessionManager.setup_mpc(session)

    public = th.Tensor([-5, 0, 1, 2, 3])
    x = MPCTensor(secret=th.Tensor([30]), shape=(1,), session=session)

    # The single-element secret is added to every element of the public tensor.
    expected = th.Tensor([25.0, 30.0, 31.0, 32.0, 33.0])
    assert ((x + public).reconstruct() == expected).all()
def test_prime_xor(get_clients, security, bit) -> None:
    """XOR of a prime-ring secret with a shared bit matches plaintext XOR."""
    parties = get_clients(3)
    protocol = Falcon(security)
    session = Session(protocol=protocol, parties=parties)
    # Configure the prime ring before MPC setup.
    session.ring_size = PRIME_NUMBER
    SessionManager.setup_mpc(session)
    ring_size = PRIME_NUMBER

    shares_x = [
        torch.tensor([[17, 44], [8, 20]], dtype=torch.uint8),
        torch.tensor([[8, 51], [27, 52]], dtype=torch.uint8),
        torch.tensor([[42, 40], [32, 63]], dtype=torch.uint8),
    ]
    first_bit, second_bit, third_bit = bit
    shares_b = [
        torch.tensor([value], dtype=torch.uint8)
        for value in (first_bit, second_bit, third_bit)
    ]

    rst_x = ReplicatedSharedTensor.distribute_shares(
        shares=shares_x, session=session, ring_size=ring_size
    )
    rst_b = ReplicatedSharedTensor.distribute_shares(
        shares=shares_b, session=session, ring_size=ring_size
    )

    x = MPCTensor(shares=rst_x, session=session)
    b = MPCTensor(shares=rst_b, session=session)
    # Set shapes explicitly on MPCTensors built from pre-distributed shares.
    x.shape = shares_x[0].shape
    b.shape = shares_b[0].shape

    plain_x = ReplicatedSharedTensor.shares_sum(shares_x, ring_size)
    plain_b = ReplicatedSharedTensor.shares_sum(shares_b, ring_size)

    result = operator.xor(x, b)
    expected = plain_x ^ plain_b

    assert (result.reconstruct(decode=False) == expected).all()
def test_backward(get_clients):
    """Gradients of sum(x * y) computed in MPC match PyTorch autograd."""
    session = Session(parties=get_clients(4))
    session.autograd_active = True
    SessionManager.setup_mpc(session)

    x_plain = torch.tensor([[0.125, -1.25], [-4.25, 4], [-3, 3]], requires_grad=True)
    y_plain = torch.tensor([[4.5, -2.5], [5, 2.25], [-3, 3]], requires_grad=True)

    x = MPCTensor(secret=x_plain, session=session, requires_grad=True)
    y = MPCTensor(secret=y_plain, session=session, requires_grad=True)

    # Build the same graph in MPC and in plaintext.
    product_mpc = x * y
    product_plain = x_plain * y_plain
    total_mpc = product_mpc.sum()
    total_plain = torch.sum(product_plain)

    total_mpc.backward()
    total_plain.backward()

    assert np.allclose(x.grad.get(), x_plain.grad, rtol=1e-3)
    assert np.allclose(y.grad.get(), y_plain.grad, rtol=1e-3)
def test_select_shares(get_clients, security) -> None:
    """Falcon.select_shares picks per element between x and y using a shared bit."""
    parties = get_clients(3)
    session = Session(parties=parties, protocol=Falcon(security))
    SessionManager.setup_mpc(session)

    mask = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool)
    mask_ptrs = ReplicatedSharedTensor.distribute_shares([mask, mask, mask], session, ring_size=2)
    b = MPCTensor(shares=mask_ptrs, session=session, shape=mask.shape)

    x = MPCTensor(secret=torch.tensor([[1, 2], [3, 4]]), session=session)
    y = MPCTensor(secret=torch.tensor([[5, 6], [7, 8]]), session=session)

    selected = Falcon.select_shares(x, y, b)

    # Per the expected values: where the mask bit is 1 the entry comes from y,
    # otherwise from x.
    expected = torch.tensor([[5.0, 2.0], [3.0, 8.0]])
    assert (expected == selected.reconstruct()).all()
def test_private_compare(get_clients, security) -> None:
    """Falcon.private_compare on prime-injected bit shares against a public value."""
    parties = get_clients(3)
    session = Session(parties=parties, protocol=Falcon(security_type=security))
    SessionManager.setup_mpc(session)

    fp_encoder = FixedPointEncoder(
        base=session.config.encoder_base,
        precision=session.config.encoder_precision,
    )

    secret = torch.tensor([[358.85, 79.29], [67.78, 2415.50]])
    r = fp_encoder.encode(torch.tensor([[357.05, 90], [145.32, 2400.54]]))

    x = MPCTensor(secret=secret, session=session)
    bit_shares = ABY3.bit_decomposition_ttp(x, session)
    # Lift each binary share into the prime ring.
    prime_shares = [
        ABY3.bit_injection(bit_share, session, PRIME_NUMBER) for bit_share in bit_shares
    ]

    tensor_type = get_type_from_ring(session.ring_size)
    result = Falcon.private_compare(prime_shares, r.type(tensor_type))

    # For these inputs the expected bit is set exactly where secret exceeds r.
    expected = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool)
    assert (result.reconstruct(decode=False) == expected).all()
def test_bit_injection_prime(get_clients, security_type) -> None:
    """ABY3.bit_injection moves binary shares into the prime ring losslessly."""
    parties = get_clients(3)
    session = Session(parties=parties, protocol=Falcon(security_type=security_type))
    SessionManager.setup_mpc(session)

    bits = torch.tensor([[1, 1], [0, 0]], dtype=torch.bool)
    bit_ptrs = ReplicatedSharedTensor.distribute_shares([bits, bits, bits], session, ring_size=2)
    x = MPCTensor(shares=bit_ptrs, session=session, shape=bits.shape)

    injected = ABY3.bit_injection(x, session, PRIME_NUMBER)
    injected_ring = int(injected.share_ptrs[0].get_ring_size().get_copy())
    reconstructed = injected.reconstruct(decode=False)

    assert (reconstructed == bits.type(torch.uint8)).all()
    assert PRIME_NUMBER == injected_ring
def test_hook_method(get_clients) -> None:
    """Hooked tensor methods on a ReplicatedSharedTensor are applied to each share."""
    clients = get_clients(3)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    x = torch.randn(1, 3)
    y = torch.randn(1, 3)

    rst = ReplicatedSharedTensor()
    rst.shares = [x, y]

    # Share ``idx`` of each result must equal the torch op applied to that share.
    for idx, plain in enumerate((x, y)):
        assert rst.numel() == plain.numel()
        assert (rst.t().shares[idx] == plain.t()).all()
        assert (rst.unsqueeze(dim=0).shares[idx] == plain.unsqueeze(dim=0)).all()
        assert (rst.view(3, 1).shares[idx] == plain.view(3, 1)).all()
        assert (rst.sum().shares[idx] == plain.sum()).all()
def test_backward_with_one_requires_grad(get_clients):
    """Only the operand that requires grad receives a gradient from backward()."""
    session = Session(parties=get_clients(4))
    session.autograd_active = True
    SessionManager.setup_mpc(session)

    x_plain = torch.tensor([[0.125, -1.25], [-4.25, 4], [-3, 3]], requires_grad=True)
    y_plain = torch.tensor([[4.5, -2.5], [5, 2.25], [-3, 3]])

    x = MPCTensor(secret=x_plain, session=session, requires_grad=True)
    y = MPCTensor(secret=y_plain, session=session)

    difference_mpc = x - y
    difference_plain = x_plain - y_plain
    total_mpc = difference_mpc.sum()
    total_plain = torch.sum(difference_plain)

    total_mpc.backward()
    total_plain.backward()

    # TODO: add assert for difference_mpc.grad and difference_plain.grad
    assert difference_mpc.requires_grad
    assert np.allclose(x.grad.get(), x_plain.grad, rtol=1e-3)
    assert y.grad is None
def test_primitive_logging_beaver_conv2d(get_clients) -> None:
    """Logging records the (p_kwargs, g_kwargs) pair of a beaver_conv2d request."""
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    a_shape = (1, 1, 28, 28)
    b_shape = (5, 1, 5, 5)
    p_kwargs = {"a_shape": a_shape, "b_shape": b_shape}
    g_kwargs = {"a_shape": a_shape, "b_shape": b_shape, "nr_parties": 2}

    CryptoPrimitiveProvider.start_logging()
    CryptoPrimitiveProvider.generate_primitives(
        sessions=session.session_ptrs,
        op_str="beaver_conv2d",
        p_kwargs=p_kwargs,
        g_kwargs=g_kwargs,
    )
    logged = CryptoPrimitiveProvider.stop_logging()

    assert {"beaver_conv2d": [(p_kwargs, g_kwargs)]} == logged
def test_ops_bin_public_xor(get_clients, security, bit) -> None:
    """XOR of a binary-ring MPCTensor with a public bit matches plaintext XOR."""
    parties = get_clients(3)
    session = Session(protocol=Falcon(security), parties=parties)
    SessionManager.setup_mpc(session)

    ring_size = 2
    bits = torch.tensor([[0, 1, 0], [1, 0, 1]], dtype=torch.bool)
    shares = [bits] * 3

    rst_ptrs = ReplicatedSharedTensor.distribute_shares(
        shares=shares, session=session, ring_size=ring_size
    )
    tensor = MPCTensor(shares=rst_ptrs, session=session)
    tensor.shape = bits.shape

    secret = ReplicatedSharedTensor.shares_sum(shares, ring_size)
    public_bit = torch.tensor([bit], dtype=torch.bool)

    result = operator.xor(tensor, public_bit)

    assert (result.reconstruct(decode=False) == (secret ^ public_bit)).all()