Exemplo n.º 1
0
def test_mse_loss(get_clients, reduction) -> None:
    """Compare the MPC ``mse_loss`` with ``torch.nn.functional.mse_loss``.

    Two tensors are secret-shared among four parties; the reconstructed MPC
    loss must match the plaintext PyTorch loss (same ``reduction``) within
    ``atol=1e-4``.
    """
    session = Session(parties=get_clients(4))
    SessionManager.setup_mpc(session)

    y_true = torch.Tensor([0.23, 0.32, 0.2, 0.3])
    y_true_mpc = MPCTensor(secret=y_true, session=session)

    y_pred = torch.Tensor([0.1, 0.3, 0.4, 0.2])
    y_pred_mpc = MPCTensor(secret=y_pred, session=session)

    loss_mpc = mse_loss(y_true_mpc, y_pred_mpc, reduction)
    loss_plain = torch.nn.functional.mse_loss(y_true, y_pred, reduction=reduction)

    assert np.allclose(loss_mpc.reconstruct(), loss_plain, atol=1e-4)
Exemplo n.º 2
0
def test_session_manager_init():
    """Check default SessionManager UUID creation and custom Session UUIDs."""
    # A default-constructed SessionManager must expose a UUID instance.
    manager = SessionManager()
    assert isinstance(manager.uuid, UUID)

    # A Session built with an explicit UUID must keep exactly that UUID.
    custom_uuid = uuid4()
    session = Session(uuid=custom_uuid)
    assert session.uuid == custom_uuid
Exemplo n.º 3
0
def test_prrs_rst_ring_size(get_clients) -> None:
    """Generate random replicated shares for every ring size in RING_SIZE_TO_TYPE.

    Each generated share must report the requested ring size and hold tensors
    of the mapped dtype. For ``PRIME_NUMBER`` the values must additionally lie
    in ``[0, PRIME_NUMBER - 1]``.
    """
    clients = get_clients(3)
    session = Session(protocol=Falcon(), parties=clients)
    SessionManager.setup_mpc(session)
    first_party = session.session_ptrs[0]

    for ring_size, expected_dtype in RING_SIZE_TO_TYPE.items():
        # The ring size is passed to the remote call as a string.
        share = first_party.prrs_generate_random_share(
            shape=(1, 2), ring_size=str(ring_size)
        ).get_copy()

        assert share.ring_size == ring_size
        assert share.shares[0].dtype == expected_dtype
        assert share.shares[1].dtype == expected_dtype

        if ring_size != PRIME_NUMBER:
            continue

        joined = torch.cat(share.shares)
        assert torch.max(joined) <= PRIME_NUMBER - 1
        assert torch.min(joined) >= 0
Exemplo n.º 4
0
def test_max_dim(dim, keepdim, get_clients) -> None:
    """Check ``MPCTensor.max`` along a dimension against ``torch.Tensor.max``.

    ``x.max(dim=dim, keepdim=keepdim)`` returns a (values, indices) pair. Both
    must be MPCTensors, and both must reconstruct to exactly what PyTorch
    returns for the plaintext tensor.
    """
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    # No ties along any dimension, so the expected indices are unambiguous.
    secret = torch.Tensor([[[1, 2], [3, -1], [4, 5]], [[2, 5], [5, 1], [6, 42]]])
    x = MPCTensor(secret=secret, session=session)

    max_val, max_idx_val = x.max(dim=dim, keepdim=keepdim)
    # Check the outputs of max(). The input ``x`` is trivially an MPCTensor,
    # so asserting on it would never fail.
    assert isinstance(max_val, MPCTensor), "Expected max to be MPCTensor"
    assert isinstance(max_idx_val, MPCTensor), "Expected argmax to be MPCTensor"

    res_idx = max_idx_val.reconstruct()
    res_max = max_val.reconstruct()
    expected_max, expected_indices = secret.max(dim=dim, keepdim=keepdim)
    assert (
        res_idx == expected_indices
    ).all(), f"Expected indices for maximum to be {expected_indices}"
    assert (res_max == expected_max).all(), f"Expected maximum to be {expected_max}"
Exemplo n.º 5
0
def test_ops_mpc_public(get_clients, nr_clients, op_str) -> None:
    """Apply ``operator.<op_str>`` between an MPCTensor and a public tensor.

    For ``truediv`` the public operand is an integer (long) tensor; for every
    other operator it is a float tensor. The reconstructed result must match
    the plaintext result within ``atol=10e-4``.
    """
    session = Session(parties=get_clients(nr_clients))
    SessionManager.setup_mpc(session)

    x_secret = torch.Tensor([[0.125, -1.25], [-4.25, 4]])
    y_public = (
        torch.Tensor([[2, 3], [4, 5]]).long()
        if op_str == "truediv"
        else torch.Tensor([[4.5, -2.5], [5, 2.25]])
    )

    x_mpc = MPCTensor(secret=x_secret, session=session)

    operation = getattr(operator, op_str)
    expected = operation(x_secret, y_public)
    actual = operation(x_mpc, y_public).reconstruct()
    assert np.allclose(actual, expected, atol=10e-4)
def test_ops_public_mul_integer_parties(get_clients, parties, security):
    """Multiply each party's share by a public integer and rebuild the tensor.

    The session uses ``encoder_base=1`` and ``encoder_precision=0``, which
    presumably means no fixed-point scaling. Under that config, scaling every
    share by ``value`` must reconstruct to exactly ``secret * value``.
    """
    config = Config(encoder_base=1, encoder_precision=0)

    clients = get_clients(parties)
    session = Session(protocol=Falcon(security), parties=clients, config=config)
    SessionManager.setup_mpc(session)

    secret = torch.tensor([[-100, 20, 30], [-90, 1000, 1], [1032, -323, 15]])
    value = 8

    tensor = MPCTensor(secret=secret, session=session)
    scaled_shares = [operator.mul(share_ptr, value) for share_ptr in tensor.share_ptrs]
    result = MPCTensor(shares=scaled_shares, session=session)

    assert (result.reconstruct() == (secret * value)).all()
Exemplo n.º 7
0
def test_conv_mpc_mpc(get_clients, nr_clients, bias, stride, padding, op_str) -> None:
    """Run a convolution where both the input and the weight are secret-shared.

    ``op_str`` names both the ``MPCTensor`` method and the matching
    ``torch.nn.functional`` function. Both are called with identical
    ``bias``/``stride``/``padding`` keyword arguments and compared with
    ``rtol=10e-4``.
    """
    session = Session(parties=get_clients(nr_clients))
    SessionManager.setup_mpc(session)

    input_plain = torch.ones(1, 1, 4, 4)
    weight_plain = torch.ones(1, 1, 2, 2)
    input_mpc = MPCTensor(secret=input_plain, session=session)
    weight_mpc = MPCTensor(secret=weight_plain, session=session)

    conv_kwargs = dict(bias=bias, stride=stride, padding=padding)

    mpc_conv = getattr(MPCTensor, op_str)
    result = mpc_conv(input_mpc, weight_mpc, **conv_kwargs).reconstruct()

    torch_conv = getattr(torch.nn.functional, op_str)
    expected_result = torch_conv(input_plain, weight_plain, **conv_kwargs)

    assert np.allclose(result, expected_result, rtol=10e-4)
Exemplo n.º 8
0
def test_mul_private_matrix(get_clients, security, base, precision):
    """Element-wise multiply two secret-shared 3x3 matrices under Falcon.

    The fixed-point encoder uses the parametrized ``base`` and ``precision``.
    The reconstructed product must match ``lhs * rhs`` within ``atol=1e-3``.
    """
    parties = get_clients(3)
    session = Session(
        protocol=Falcon(security),
        parties=parties,
        config=Config(encoder_base=base, encoder_precision=precision),
    )
    SessionManager.setup_mpc(session)

    lhs = torch.tensor([
        [-100.25, 20.3, 30.12],
        [-50.1, 100.217, 1.2],
        [1032.15, -323.56, 15.15],
    ])
    rhs = torch.tensor([[-1, 0.28, 3], [-9, 10.18, 1], [32, -23, 5]])

    lhs_mpc = MPCTensor(secret=lhs, session=session)
    rhs_mpc = MPCTensor(secret=rhs, session=session)

    product = lhs_mpc * rhs_mpc
    assert np.allclose(product.reconstruct(), lhs * rhs, atol=1e-3)
Exemplo n.º 9
0
def test_eq() -> None:
    """Check ``Session.__eq__`` before and after ``SessionManager.setup_mpc``."""
    session = Session()
    fresh = Session()
    alias = session

    # Comparing with a non-Session object (here an int) gives inequality.
    assert session != 1

    # The very same object compares equal to itself.
    assert session == alias

    # Two distinct, not-yet-set-up sessions compare equal...
    assert session == fresh

    # ...but once one of them has been set up they no longer do.
    SessionManager.setup_mpc(session)
    assert session != fresh
Exemplo n.º 10
0
def test_reconstruct(get_clients) -> None:
    """Smoke-test ``generate_shares``, then check ``reconstruct`` round-trips.

    The two ``generate_shares`` calls have no assertions. They only need to
    run without raising: once with a ``ShareTensor`` secret and once with a
    plain int secret.
    """
    session = Session(parties=get_clients(2))
    SessionManager.setup_mpc(session)

    raw_value = 3
    share_input = ShareTensor(data=raw_value, config=Config(encoder_precision=0))
    MPCTensor.generate_shares(secret=share_input, nr_parties=2, tensor_type=torch.long)
    MPCTensor.generate_shares(
        secret=raw_value, nr_parties=2, config=Config(), tensor_type=torch.long
    )

    x_secret = torch.Tensor([1, -2, 3.0907, -4.870])
    reconstructed = MPCTensor(secret=x_secret, session=session).reconstruct()

    assert np.allclose(x_secret, reconstructed)
Exemplo n.º 11
0
def test_bit_decomposition_ttp(get_clients, security_type) -> None:
    """Recompose the TTP bit decomposition of a secret and check its value.

    Each element of the decomposition is reconstructed undecoded and OR-ed in
    at its bit position. The result must equal the raw ring value of
    ``secret``.
    """
    parties = get_clients(3)
    session = Session(parties=parties, protocol=Falcon(security_type=security_type))
    SessionManager.setup_mpc(session)

    secret = torch.tensor([[-1, 12], [-32, 45], [98, -5624]])
    x = MPCTensor(secret=secret, session=session)
    bit_shares = ABY3.bit_decomposition_ttp(x, session)

    tensor_type = x.session.tensor_type
    nr_bits = get_nr_bits(x.session.ring_size)

    recomposed = torch.zeros(size=x.shape, dtype=tensor_type)
    for position in range(nr_bits):
        bit = bit_shares[position].reconstruct(decode=False).type(tensor_type)
        recomposed |= bit << position

    # Every entry below equals secret * 65536 (2**16), presumably the
    # session's default fixed-point scale applied before decomposition.
    expected = torch.tensor([[-65536, 786432], [-2097152, 2949120],
                             [6422528, -368574464]])
    assert (recomposed == expected).all()
Exemplo n.º 12
0
def test_truncation_algorithm1(get_clients, base, precision) -> None:
    """Compare ``ABY3.truncate`` with plaintext floor division by the scale.

    The reference value is the raw (undecoded) ring value of ``x_mpc``,
    floor-divided by the encoder scale and then decoded.
    """
    parties = get_clients(3)
    protocol = Falcon("semi-honest")
    config = Config(encoder_base=base, encoder_precision=precision)
    session = Session(parties=parties, protocol=protocol, config=config)
    SessionManager.setup_mpc(session)

    x = torch.tensor([[1.24, 4.51, 6.87], [7.87, 1301, 541]])
    x_mpc = MPCTensor(secret=x, session=session)

    truncated = ABY3.truncate(x_mpc, session, session.ring_size, session.config)

    encoder = FixedPointEncoder(
        base=session.config.encoder_base,
        precision=session.config.encoder_precision,
    )
    raw_value = x_mpc.reconstruct(decode=False)
    expected_res = encoder.decode(raw_value // encoder.scale)

    assert np.allclose(truncated.reconstruct(), expected_res, atol=1e-3)
def test_send_get(get_clients, precision=12, base=4) -> None:
    """Send a ReplicatedSharedTensor to a client and fetch it back unchanged.

    NOTE(review): ``precision`` and ``base`` are accepted but never used.
    """
    client = get_clients(1)[0]
    session = Session(protocol=Falcon("semi-honest"), parties=[client])
    SessionManager.setup_mpc(session)

    share_values = [
        torch.Tensor([1.4, 2.34, 3.43]),
        torch.Tensor([1, 2, 3]),
        torch.Tensor([1.4, 2.34, 3.43]),
    ]
    x_share = ReplicatedSharedTensor(
        shares=share_values, session_uuid=session.rank_to_uuid[0]
    )

    round_tripped = x_share.send(client).get()
    assert round_tripped == x_share
def test_primitive_logging_beaver_matmul(get_clients) -> None:
    """Logging a ``beaver_matmul`` generation records its (p_kwargs, g_kwargs) pair."""
    session = Session(parties=get_clients(2))
    SessionManager.setup_mpc(session)

    a_shape, b_shape = (2, 3), (3, 10)
    p_kwargs = {"a_shape": a_shape, "b_shape": b_shape}
    g_kwargs = {"a_shape": a_shape, "b_shape": b_shape, "nr_parties": 2}

    CryptoPrimitiveProvider.start_logging()
    CryptoPrimitiveProvider.generate_primitives(
        op_str="beaver_matmul",
        session=session,
        g_kwargs=g_kwargs,
        p_kwargs=p_kwargs,
    )
    logged = CryptoPrimitiveProvider.stop_logging()

    assert {"beaver_matmul": [(p_kwargs, g_kwargs)]} == logged
Exemplo n.º 15
0
def test_backward_without_requires_grad(get_clients):
    """Check that backward produces no gradients when nothing requires them.

    Autograd is active on the session, but neither operand has
    ``requires_grad``.
    """
    session = Session(parties=get_clients(4))
    session.autograd_active = True
    SessionManager.setup_mpc(session)

    x_plain = torch.tensor([[0.125, -1.25], [-4.25, 4], [-3, 3]])
    y_plain = torch.tensor([[4.5, -2.5], [5, 2.25], [-3, 3]])
    x = MPCTensor(secret=x_plain, session=session)
    y = MPCTensor(secret=y_plain, session=session)

    diff = x - y
    diff.sum().backward()

    assert not diff.requires_grad
    for tensor in (diff, x, y):
        assert tensor.grad is None
Exemplo n.º 16
0
def test_prrs_share_tensor() -> None:
    """Test prrs_generate_random_share method from Session for ShareTensor.

    The session's PRZS generators are replaced with two freshly seeded ones.
    The generated share must then equal the random element drawn from a new
    generator seeded with the first seed.
    """
    session = Session()  # no protocol given: the session's default is used
    SessionManager.setup_mpc(session)
    seed1 = secrets.randbits(32)
    seed2 = secrets.randbits(32)
    session.przs_generators = [get_new_generator(seed1), get_new_generator(seed2)]

    shape = (2, 1)
    share = session.prrs_generate_random_share(shape=shape)
    assert isinstance(share, ShareTensor)

    # Replaying the first seed must reproduce exactly the tensor the session drew.
    expected = generate_random_element(
        generator=get_new_generator(seed1),
        shape=shape,
        tensor_type=session.tensor_type,
    )
    assert (share.tensor == expected).all()
Exemplo n.º 17
0
def test_generate_primitive_from_dict_beaver_conv2d(get_clients) -> None:
    """Generate ``beaver_conv2d`` primitives from a log dict and check their shapes.

    Both clients' crypto stores must hold an entry under the key
    ``beaver_conv2d_{a_shape}_{b_shape}``. The first two tensors of that
    entry's first element must have shapes ``a_shape`` and ``b_shape``.
    """
    session = Session(parties=get_clients(2))
    SessionManager.setup_mpc(session)

    a_shape = (1, 1, 28, 28)
    b_shape = (5, 1, 5, 5)

    primitive_log = {
        "beaver_conv2d": [
            (
                {"a_shape": a_shape, "b_shape": b_shape},
                {
                    "session": session,
                    "a_shape": a_shape,
                    "b_shape": b_shape,
                    "nr_parties": 2,
                },
            )
        ]
    }
    CryptoPrimitiveProvider.generate_primitive_from_dict(
        primitive_log=primitive_log, session=session
    )

    key = f"beaver_conv2d_{a_shape}_{b_shape}"
    client_stores = [
        session.session_ptrs[rank].crypto_store.store.get() for rank in (0, 1)
    ]

    for store in client_stores:
        first_entry = store.get(key)[0]
        assert a_shape == tuple(first_entry[0].shape)
        assert b_shape == tuple(first_entry[1].shape)
def test_generate_primitive(get_clients: Callable, nr_parties: int,
                            nr_instances: int) -> None:
    """Generate the ``test`` primitive and check each party's returned value.

    One list of primitives is returned per party. Every primitive for party
    ``rank`` must be a tuple of ``PRIMITIVE_NR_ELEMS`` copies of ``rank``.
    """
    session = Session(parties=get_clients(nr_parties))
    SessionManager.setup_mpc(session)

    res = CryptoPrimitiveProvider.generate_primitives(
        "test",
        sessions=session.session_ptrs,
        g_kwargs={"nr_parties": nr_parties, "nr_instances": nr_instances},
        p_kwargs=None,
    )

    assert isinstance(res, list)
    assert len(res) == nr_parties

    for rank, party_primitives in enumerate(res):
        expected = (rank,) * PRIMITIVE_NR_ELEMS
        for primitive in party_primitives:
            assert primitive == expected
Exemplo n.º 19
0
def test_generate_and_transfer_primitive(
    get_clients: Callable, nr_parties: int, nr_instances: int
) -> None:
    """Check that generated ``test`` primitives land in every party's crypto store.

    Party ``rank`` must end up with a store list holding exactly one
    primitive: a tuple of ``PRIMITIVE_NR_ELEMS`` copies of ``rank``.
    """
    session = Session(parties=get_clients(nr_parties))
    SessionManager.setup_mpc(session)

    CryptoPrimitiveProvider.generate_primitives(
        "test",
        n_instances=nr_instances,
        sessions=session.session_ptrs,
        g_kwargs={"nr_parties": nr_parties},
        p_kwargs={},
    )

    for rank in range(nr_parties):
        store_ptr = session.session_ptrs[rank].crypto_store
        stored = store_ptr.get_primitives_from_store("test").get()
        assert stored == [(rank,) * PRIMITIVE_NR_ELEMS]
def test_generate_primitive_from_dict_beaver_matmul(get_clients) -> None:
    """Generate ``beaver_matmul`` primitives from a log dict and check their shapes.

    Both clients' crypto stores must hold an entry under the key
    ``beaver_matmul_{a_shape}_{b_shape}``. The first two tensors of that
    entry's first element must have shapes ``a_shape`` and ``b_shape``.
    """
    session = Session(parties=get_clients(2))
    SessionManager.setup_mpc(session)

    a_shape = (2, 3)
    b_shape = (3, 10)

    primitive_log = {
        "beaver_matmul": [
            (
                {"a_shape": a_shape, "b_shape": b_shape},
                {"a_shape": a_shape, "b_shape": b_shape, "nr_parties": 2},
            )
        ]
    }
    CryptoPrimitiveProvider.generate_primitive_from_dict(
        primitive_log=primitive_log, session=session
    )

    key = f"beaver_matmul_{a_shape}_{b_shape}"
    client_stores = [
        session.session_ptrs[rank].crypto_store.store.get() for rank in (0, 1)
    ]

    for store in client_stores:
        first_entry = store.get(key)[0]
        assert a_shape == tuple(first_entry[0].shape)
        assert b_shape == tuple(first_entry[1].shape)
Exemplo n.º 21
0
def test_load_sympc() -> None:
    """Load the sympc plugin into syft, then add a public tensor to an MPCTensor."""
    alice_vm = sy.VirtualMachine()
    alice_client = alice_vm.get_root_client()
    bob_vm = sy.VirtualMachine()
    bob_client = bob_vm.get_root_client()

    # sympc names are imported inside the test, ahead of ``sy.load("sympc")``.
    from sympc.session import Session
    from sympc.session import SessionManager
    from sympc.tensor import MPCTensor

    sy.load("sympc")

    session = Session(parties=[alice_client, bob_client])
    SessionManager.setup_mpc(session)

    public = th.Tensor([-5, 0, 1, 2, 3])
    x = MPCTensor(secret=th.Tensor([30]), shape=(1,), session=session)

    expected = th.Tensor([25.0, 30.0, 31.0, 32.0, 33.0])
    assert ((x + public).reconstruct() == expected).all()
Exemplo n.º 22
0
def test_prime_xor(get_clients, security, bit) -> None:
    """XOR a secret-shared tensor with a secret-shared bit in the prime ring.

    Shares are distributed by hand with ``ring_size=PRIME_NUMBER``. The
    undecoded reconstruction of ``x ^ b`` must equal the XOR of the two
    share sums.
    """
    parties = get_clients(3)
    session = Session(protocol=Falcon(security), parties=parties)
    session.ring_size = PRIME_NUMBER
    SessionManager.setup_mpc(session)
    ring_size = PRIME_NUMBER

    shares_x = [
        torch.tensor([[17, 44], [8, 20]], dtype=torch.uint8),
        torch.tensor([[8, 51], [27, 52]], dtype=torch.uint8),
        torch.tensor([[42, 40], [32, 63]], dtype=torch.uint8),
    ]
    bit_sh_1, bit_sh_2, bit_sh_3 = bit
    shares_b = [
        torch.tensor([value], dtype=torch.uint8)
        for value in (bit_sh_1, bit_sh_2, bit_sh_3)
    ]

    ptrs_x = ReplicatedSharedTensor.distribute_shares(
        shares=shares_x, session=session, ring_size=ring_size
    )
    ptrs_b = ReplicatedSharedTensor.distribute_shares(
        shares=shares_b, session=session, ring_size=ring_size
    )

    x = MPCTensor(shares=ptrs_x, session=session)
    b = MPCTensor(shares=ptrs_b, session=session)
    # Shapes are assigned by hand; presumably they are not inferred when an
    # MPCTensor is built from share pointers.
    x.shape = shares_x[0].shape
    b.shape = shares_b[0].shape

    plain_x = ReplicatedSharedTensor.shares_sum(shares_x, ring_size)
    plain_b = ReplicatedSharedTensor.shares_sum(shares_b, ring_size)

    result = operator.xor(x, b)

    assert (result.reconstruct(decode=False) == plain_x ^ plain_b).all()
Exemplo n.º 23
0
def test_backward(get_clients):
    """Check that MPC gradients of ``sum(x * y)`` match PyTorch autograd (rtol=1e-3)."""
    session = Session(parties=get_clients(4))
    session.autograd_active = True
    SessionManager.setup_mpc(session)

    x_plain = torch.tensor([[0.125, -1.25], [-4.25, 4], [-3, 3]],
                           requires_grad=True)
    y_plain = torch.tensor([[4.5, -2.5], [5, 2.25], [-3, 3]],
                           requires_grad=True)
    x_mpc = MPCTensor(secret=x_plain, session=session, requires_grad=True)
    y_mpc = MPCTensor(secret=y_plain, session=session, requires_grad=True)

    product_mpc = x_mpc * y_mpc
    product_plain = x_plain * y_plain
    total_mpc = product_mpc.sum()
    total_plain = torch.sum(product_plain)
    total_mpc.backward()
    total_plain.backward()

    for mpc_leaf, plain_leaf in ((x_mpc, x_plain), (y_mpc, y_plain)):
        assert np.allclose(mpc_leaf.grad.get(), plain_leaf.grad, rtol=1e-3)
Exemplo n.º 24
0
def test_select_shares(get_clients, security) -> None:
    """Check ``Falcon.select_shares(x, y, b)`` with a hand-built selector ``b``.

    ``b`` is built from three identical ring-2 shares of ``selector``. The
    expected result takes ``y`` where ``selector`` is 1 and ``x`` elsewhere.
    """
    parties = get_clients(3)
    session = Session(parties=parties, protocol=Falcon(security))
    SessionManager.setup_mpc(session)

    selector = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool)
    selector_ptrs = ReplicatedSharedTensor.distribute_shares(
        [selector] * 3, session, ring_size=2
    )
    b = MPCTensor(shares=selector_ptrs, session=session, shape=selector.shape)

    x = MPCTensor(secret=torch.tensor([[1, 2], [3, 4]]), session=session)
    y = MPCTensor(secret=torch.tensor([[5, 6], [7, 8]]), session=session)

    selected = Falcon.select_shares(x, y, b)

    expected_res = torch.tensor([[5.0, 2.0], [3.0, 8.0]])
    assert (expected_res == selected.reconstruct()).all()
Exemplo n.º 25
0
def test_private_compare(get_clients, security) -> None:
    """Run ``Falcon.private_compare`` on prime-ring bit shares of ``x`` and a public ``r``.

    Every bit share of ``x`` is injected into the ``PRIME_NUMBER`` ring first.
    The expected boolean output is 1 exactly where ``secret > r`` element-wise
    (358.85 > 357.05 and 2415.50 > 2400.54).
    """
    parties = get_clients(3)
    session = Session(parties=parties, protocol=Falcon(security_type=security))
    SessionManager.setup_mpc(session)
    fp_encoder = FixedPointEncoder(
        base=session.config.encoder_base,
        precision=session.config.encoder_precision,
    )

    secret = torch.tensor([[358.85, 79.29], [67.78, 2415.50]])
    r = fp_encoder.encode(torch.tensor([[357.05, 90], [145.32, 2400.54]]))

    x = MPCTensor(secret=secret, session=session)
    bit_shares = ABY3.bit_decomposition_ttp(x, session)
    prime_shares = [
        ABY3.bit_injection(bit_share, session, PRIME_NUMBER)
        for bit_share in bit_shares
    ]

    tensor_type = get_type_from_ring(session.ring_size)
    result = Falcon.private_compare(prime_shares, r.type(tensor_type))

    expected_res = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool)
    assert (result.reconstruct(decode=False) == expected_res).all()
Exemplo n.º 26
0
def test_bit_injection_prime(get_clients, security_type) -> None:
    """Inject a ring-2 boolean MPCTensor into the ``PRIME_NUMBER`` ring.

    The result must reconstruct to the same 0/1 values (as uint8). Its first
    share pointer must report ring size ``PRIME_NUMBER``.
    """
    parties = get_clients(3)
    session = Session(parties=parties, protocol=Falcon(security_type=security_type))
    SessionManager.setup_mpc(session)
    ring_size = PRIME_NUMBER

    bin_sh = torch.tensor([[1, 1], [0, 0]], dtype=torch.bool)
    bin_ptrs = ReplicatedSharedTensor.distribute_shares(
        [bin_sh] * 3, session, ring_size=2
    )
    x = MPCTensor(shares=bin_ptrs, session=session, shape=bin_sh.shape)

    injected = ABY3.bit_injection(x, session, ring_size)

    reported_ring = int(injected.share_ptrs[0].get_ring_size().get_copy())
    reconstructed = injected.reconstruct(decode=False)

    assert (reconstructed == bin_sh.type(torch.uint8)).all()
    assert ring_size == reported_ring
Exemplo n.º 27
0
def test_hook_method(get_clients) -> None:
    """Check that hooked tensor methods on ReplicatedSharedTensor act on each share.

    For every share, ``numel``, ``t``, ``unsqueeze``, ``view`` and ``sum`` must
    give the same result as calling them on the underlying torch tensor.
    """
    clients = get_clients(3)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    x = torch.randn(1, 3)
    y = torch.randn(1, 3)

    rst = ReplicatedSharedTensor()
    rst.shares = [x, y]

    for idx, plain in enumerate((x, y)):
        assert rst.numel() == plain.numel()
        assert (rst.t().shares[idx] == plain.t()).all()
        assert (rst.unsqueeze(dim=0).shares[idx] == plain.unsqueeze(dim=0)).all()
        assert (rst.view(3, 1).shares[idx] == plain.view(3, 1)).all()
        assert (rst.sum().shares[idx] == plain.sum()).all()
Exemplo n.º 28
0
def test_backward_with_one_requires_grad(get_clients):
    """Check that only the operand with ``requires_grad=True`` receives a gradient.

    ``sum(x - y)`` is back-propagated both through MPC and through PyTorch.
    ``x``'s MPC gradient must match PyTorch's within ``rtol=1e-3``, while
    ``y`` gets no gradient at all.
    """
    session = Session(parties=get_clients(4))
    session.autograd_active = True
    SessionManager.setup_mpc(session)

    x_plain = torch.tensor([[0.125, -1.25], [-4.25, 4], [-3, 3]],
                           requires_grad=True)
    y_plain = torch.tensor([[4.5, -2.5], [5, 2.25], [-3, 3]])
    x_mpc = MPCTensor(secret=x_plain, session=session, requires_grad=True)
    y_mpc = MPCTensor(secret=y_plain, session=session)

    diff_mpc = x_mpc - y_mpc
    diff_plain = x_plain - y_plain
    total_mpc = diff_mpc.sum()
    total_plain = torch.sum(diff_plain)
    total_mpc.backward()
    total_plain.backward()

    # TODO: add assert for diff_mpc.grad and diff_plain.grad
    assert diff_mpc.requires_grad
    assert np.allclose(x_mpc.grad.get(), x_plain.grad, rtol=1e-3)
    assert y_mpc.grad is None
def test_primitive_logging_beaver_conv2d(get_clients) -> None:
    """Logging a ``beaver_conv2d`` generation records its (p_kwargs, g_kwargs) pair."""
    session = Session(parties=get_clients(2))
    SessionManager.setup_mpc(session)

    a_shape, b_shape = (1, 1, 28, 28), (5, 1, 5, 5)
    p_kwargs = {"a_shape": a_shape, "b_shape": b_shape}
    g_kwargs = {"a_shape": a_shape, "b_shape": b_shape, "nr_parties": 2}

    CryptoPrimitiveProvider.start_logging()
    CryptoPrimitiveProvider.generate_primitives(
        op_str="beaver_conv2d",
        sessions=session.session_ptrs,
        g_kwargs=g_kwargs,
        p_kwargs=p_kwargs,
    )
    logged = CryptoPrimitiveProvider.stop_logging()

    assert {"beaver_conv2d": [(p_kwargs, g_kwargs)]} == logged
def test_ops_bin_public_xor(get_clients, security, bit) -> None:
    """XOR a ring-2 secret-shared boolean tensor with a public boolean bit.

    The undecoded reconstruction must equal ``shares_sum(shares, 2) ^ bit``.
    """
    parties = get_clients(3)
    session = Session(protocol=Falcon(security), parties=parties)
    SessionManager.setup_mpc(session)
    ring_size = 2

    sh = torch.tensor([[0, 1, 0], [1, 0, 1]], dtype=torch.bool)
    shares = [sh] * 3
    share_ptrs = ReplicatedSharedTensor.distribute_shares(
        shares=shares, session=session, ring_size=ring_size
    )
    tensor = MPCTensor(shares=share_ptrs, session=session)
    tensor.shape = sh.shape

    plain = ReplicatedSharedTensor.shares_sum(shares, ring_size)
    public_bit = torch.tensor([bit], dtype=torch.bool)

    result = operator.xor(tensor, public_bit)

    assert (result.reconstruct(decode=False) == plain ^ public_bit).all()