def test_determine_sign_both_betas(workers):
    """Check FalconHelper.determine_sign for both values of the sign bit.

    The assertions below require determine_sign(x, beta) to reconstruct to
    ``x`` when every beta entry is 0 and to ``-x`` when every beta entry is 1.
    """
    # NOTE: this test used to share the name ``test_determine_sign`` with a later
    # definition in this module, which replaced it at import time, so it never ran.
    parties = [workers["bob"], workers["alice"], workers["james"]]

    plain_x = torch.tensor([-4, 5, 6])
    x = plain_x.share(*parties, protocol="falcon")

    # Share beta with x's shape and in x's ring rather than as a scalar over
    # field=2. NOTE(review): assumes determine_sign expects beta in x's ring,
    # as the other determine_sign test in this file constructs it — confirm.
    ring_size = x.ring_size
    shape = x.shape
    beta_0 = torch.zeros(size=shape, dtype=torch.long).share(
        *parties, protocol="falcon", field=ring_size
    )
    beta_1 = torch.ones(size=shape, dtype=torch.long).share(
        *parties, protocol="falcon", field=ring_size
    )

    assert (FalconHelper.determine_sign(x, beta_0).reconstruct() == plain_x).all()
    assert (FalconHelper.determine_sign(x, beta_1).reconstruct() == -plain_x).all()
def test_determine_sign(beta, workers):
    """Verify determine_sign multiplies a shared tensor by (-1) ** beta.

    ``beta`` is a plain truthy/falsy value supplied by the test harness. It is
    turned into an all-ones or all-zeros share with the same shape and ring
    size as the shared input before determine_sign is called.
    """
    party_list = [workers[name] for name in ("bob", "alice", "james")]

    secret = torch.tensor([-4, 5, 6])
    secret_shared = secret.share(*party_list, protocol="falcon")

    wanted = (-1) ** beta * secret

    # Pick the constructor from the truthiness of beta, then share the result
    # in the same ring as the input tensor.
    filler = torch.ones if beta else torch.zeros
    beta_shared = filler(size=secret_shared.shape, dtype=torch.long).share(
        *party_list, protocol="falcon", field=secret_shared.ring_size
    )

    got = FalconHelper.determine_sign(secret_shared, beta_shared).reconstruct()
    assert (wanted == got).all()