def test_private_compare_exception(get_clients) -> None:
    """Verify that Falcon.private_compare raises ValueError on malformed arguments.

    Two misuse cases are covered:
    * the shared operand is passed as a bare MPCTensor instead of a list;
    * the public operand is an MPCTensor instead of a torch.Tensor.
    """
    clients = get_clients(3)
    protocol = Falcon()
    session = Session(parties=clients, protocol=protocol)
    SessionManager.setup_mpc(session)

    shared_value = MPCTensor(secret=1, session=session)
    public_value = torch.tensor([1])

    # The shared input values are expected as a list; a bare MPCTensor is rejected.
    with pytest.raises(ValueError):
        Falcon.private_compare(shared_value, public_value)

    # The second operand must be a public value (torch.Tensor), not a shared tensor.
    with pytest.raises(ValueError):
        Falcon.private_compare([shared_value], shared_value)
def test_private_compare(get_clients, security) -> None:
    """Check Falcon.private_compare against a known elementwise comparison mask.

    The secret is bit-decomposed into boolean shares with the ABY3 TTP helper.
    Each bit share is then injected into the ring modulo PRIME_NUMBER, and the
    resulting list is compared with the fixed-point encoded public tensor ``r``.
    For the values used here, the expected mask is True exactly where
    ``secret > r`` elementwise.
    """
    parties = get_clients(3)
    falcon = Falcon(security_type=security)
    session = Session(parties=parties, protocol=falcon)
    SessionManager.setup_mpc(session)

    # Encode the public operand with the same fixed-point settings as the session,
    # so it lives in the same integer domain as the encoded secret.
    base = session.config.encoder_base
    precision = session.config.encoder_precision
    fp_encoder = FixedPointEncoder(base=base, precision=precision)

    secret = torch.tensor([[358.85, 79.29], [67.78, 2415.50]])
    r = torch.tensor([[357.05, 90], [145.32, 2400.54]])
    r = fp_encoder.encode(r)

    x = MPCTensor(secret=secret, session=session)
    x_b = ABY3.bit_decomposition_ttp(x, session)  # bit shares

    # Move every bit share into the prime-field ring expected by private_compare.
    x_p = [ABY3.bit_injection(share, session, PRIME_NUMBER) for share in x_b]

    # Cast the public operand to the tensor dtype matching the session's ring size.
    tensor_type = get_type_from_ring(session.ring_size)
    result = Falcon.private_compare(x_p, r.type(tensor_type))

    expected_res = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool)
    # decode=False: the comparison output is a raw bit, not a fixed-point value.
    assert (result.reconstruct(decode=False) == expected_res).all()