def test_bit_decomposition_ttp(get_clients, security_type) -> None:
    """Check that ABY3's TTP bit decomposition can be recombined losslessly.

    The secret is shared among three parties and decomposed into bit shares.
    Each bit share is reconstructed without decoding, shifted back to its bit
    position, and OR-ed into an accumulator. The recombined ring elements must
    match the raw (still fixed-point encoded) values of the secret.
    """
    clients = get_clients(3)
    protocol = Falcon(security_type=security_type)
    mpc_session = Session(parties=clients, protocol=protocol)
    SessionManager.setup_mpc(mpc_session)

    plain = torch.tensor([[-1, 12], [-32, 45], [98, -5624]])
    shared = MPCTensor(secret=plain, session=mpc_session)
    bit_shares = ABY3.bit_decomposition_ttp(shared, mpc_session)

    ring = shared.session.ring_size
    dtype = shared.session.tensor_type
    nr_bits = get_nr_bits(ring)

    # Rebuild every ring element one bit position at a time.
    recombined = torch.zeros(size=shared.shape, dtype=dtype)
    for position in range(nr_bits):
        bit = bit_shares[position].reconstruct(decode=False).type(dtype)
        recombined |= bit << position

    # Each entry equals plain * 2**16 (e.g. -1 -> -65536). That presumably
    # reflects the session's default fixed-point encoding; confirm if the
    # encoder config changes.
    expected = torch.tensor(
        [[-65536, 786432], [-2097152, 2949120], [6422528, -368574464]]
    )
    assert (recombined == expected).all()
def test_private_compare(get_clients, security) -> None:
    """Check ``Falcon.private_compare`` against a public fixed-point value.

    The secret ``x`` is shared among three parties and decomposed into bit
    shares with ABY3's TTP bit decomposition. Each bit is then moved into the
    ``PRIME_NUMBER`` ring via bit injection, and the result is compared with
    the public comparand ``r``. ``r`` is encoded with the same fixed-point
    base and precision as the session. For this data the expected boolean
    mask is True exactly where the secret is larger than ``r``.
    """
    # NOTE(review): the sibling test names this parameter ``security_type``.
    # It is presumably a pytest fixture name, so it is left unchanged:
    # renaming it would change which fixture pytest injects.
    parties = get_clients(3)
    falcon = Falcon(security_type=security)
    session = Session(parties=parties, protocol=falcon)
    SessionManager.setup_mpc(session)

    # Encode the public comparand exactly the way the session encodes secrets.
    fp_encoder = FixedPointEncoder(
        base=session.config.encoder_base,
        precision=session.config.encoder_precision,
    )

    secret = torch.tensor([[358.85, 79.29], [67.78, 2415.50]])
    r = fp_encoder.encode(torch.tensor([[357.05, 90], [145.32, 2400.54]]))

    x = MPCTensor(secret=secret, session=session)
    x_b = ABY3.bit_decomposition_ttp(x, session)  # bit shares
    # Prime ring shares: lift every bit share into the PRIME_NUMBER ring.
    x_p = [ABY3.bit_injection(share, session, PRIME_NUMBER) for share in x_b]

    # private_compare expects r in the session ring's tensor dtype.
    tensor_type = get_type_from_ring(session.ring_size)
    result = Falcon.private_compare(x_p, r.type(tensor_type))

    # 358.85 > 357.05, 79.29 < 90, 67.78 < 145.32, 2415.50 > 2400.54
    expected_res = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool)
    assert (result.reconstruct(decode=False) == expected_res).all()