def test_invalid_mpc_pointer(get_clients) -> None:
    """ABY3.truncate raises ValueError for an MPCTensor of ShareTensor pointers."""
    session = Session(parties=get_clients(3))
    SessionManager.setup_mpc(session)

    # Default-protocol session: the tensor is backed by ShareTensor pointers,
    # which ABY3.truncate is expected to reject.
    share_tensor_mpc = MPCTensor(secret=1, session=session)
    with pytest.raises(ValueError):
        ABY3.truncate(share_tensor_mpc, session, 2**32, None)
def test_bit_injection_exception(get_clients) -> None:
    """ABY3.bit_injection raises ValueError for an ordinary (non-bit) MPCTensor."""
    session = Session(parties=get_clients(3), protocol=Falcon())
    SessionManager.setup_mpc(session)

    non_bit_tensor = MPCTensor(secret=1, session=session)
    with pytest.raises(ValueError):
        ABY3.bit_injection(non_bit_tensor, session, 2**64)
def test_eq():
    """Check Falcon equality against equal and unequal protocol objects."""
    falcon = Falcon()
    aby1 = ABY3(security_type="malicious")
    aby2 = ABY3()
    other2 = falcon

    # Test equal protocol (same object):
    assert falcon == other2

    # Comparing an object with its own alias passes even with identity-based
    # equality, so also compare against an independently constructed Falcon
    # that has the same security type.
    assert falcon == Falcon()

    # Same protocol class, different security type
    assert falcon != Falcon(security_type="malicious")

    # Different protocol class and different security type
    assert falcon != aby1

    # Different protocol class, same security type
    assert falcon != aby2
def private_compare(x: List[MPCTensor], r: torch.Tensor) -> MPCTensor:
    """Falcon Private Compare: compare secret-shared bits of x with a public r.

    The parties hold the bits of x shared in Z_p (p = PRIME_NUMBER); r is
    public. Bit i of r, i.e. (r >> i) & 1, is compared with x[i], walking from
    the highest index down to index 0.

    A random bit beta (shared in ring 2 and lifted to Z_p) flips the sign of
    every bit difference. A random multiplier m is applied to the product of
    the per-bit terms c[i]. Only whether the product d = m * prod(c) is zero
    is opened.

    Args:
        x (List[MPCTensor]): shares of the bits of x in Z_p; x[i] holds the bit
            paired with (r >> i) & 1.
        r (torch.Tensor): Public value r.

    Returns:
        result (MPCTensor): shares in ring 2 of the comparison bit: 1 if x > r
            and 0 if x < r. For x == r every c[i] equals 1, so d is non-zero
            (assuming m is non-zero) and the output is 1 XOR beta. It therefore
            depends on the random bit and must not be relied on for equal inputs.

    Raises:
        ValueError: If input shares is not a list, or is an empty list.
        ValueError: If input public value is not a tensor.
    """
    if not isinstance(x, list):
        raise ValueError(
            f"Input shares for Private Compare: {x} must be a list")

    if not isinstance(r, torch.Tensor):
        raise ValueError(
            f"Value r:{r} must be a torch tensor for private compare")

    if not x:
        # Without this check x[0] below fails with an unhelpful IndexError.
        raise ValueError(
            "Input shares for Private Compare must be a non-empty list")

    shape = x[0].shape
    session = x[0].session

    # Random bit beta: shared in ring 2, then lifted into Z_p.
    ptr_list: List[ReplicatedSharedTensor] = [
        session_ptr.prrs_generate_random_share(shape=shape, ring_size="2")
        for session_ptr in session.session_ptrs
    ]
    beta_2 = MPCTensor(shares=ptr_list, session=session, shape=shape)
    beta_p = ABY3.bit_injection(beta_2, session, PRIME_NUMBER)

    # Random multiplier applied to the product before it is opened.
    m = Falcon._random_prime_group(session, shape)

    # 1 - 2*beta is +1 (beta == 0) or -1 (beta == 1); it does not depend on
    # the bit position, so compute it once outside the loop.
    sign = 1 - 2 * beta_p

    nr_bits = len(x)
    c = [0] * nr_bits
    w = 0  # running sum of x[k] XOR r[k] over the bits above position i
    for i in reversed(range(nr_bits)):
        r_i = (r >> i) & 1  # bit at ith position
        # c[i] == 0 only if all higher bits agree (w == 0) and
        # sign * (x[i] - r_i) == -1.
        c[i] = sign * (x[i] - r_i) + 1 + w
        w += x[i] ^ r_i

    d = m * math.prod(c)
    d_val = d.reconstruct(decode=False)  # plaintext d
    # Only whether d is zero matters: beta' = 1 iff d != 0.
    d_val[d_val != 0] = 1
    beta_prime = d_val

    # Adding the public bit in ring 2 XORs it with beta.
    return beta_2 + beta_prime
def test_private_compare(get_clients, security) -> None:
    """Falcon.private_compare on bit-decomposed shares of x vs. an encoded public r."""
    parties = get_clients(3)
    protocol = Falcon(security_type=security)
    session = Session(parties=parties, protocol=protocol)
    SessionManager.setup_mpc(session)

    encoder = FixedPointEncoder(
        base=session.config.encoder_base,
        precision=session.config.encoder_precision,
    )

    secret = torch.tensor([[358.85, 79.29], [67.78, 2415.50]])
    public_r = encoder.encode(torch.tensor([[357.05, 90], [145.32, 2400.54]]))

    x = MPCTensor(secret=secret, session=session)
    # Bits of x shared in ring 2, then each lifted into the prime ring.
    bit_shares = ABY3.bit_decomposition_ttp(x, session)
    prime_shares = [
        ABY3.bit_injection(bit, session, PRIME_NUMBER) for bit in bit_shares
    ]

    ring_dtype = get_type_from_ring(session.ring_size)
    result = Falcon.private_compare(prime_shares, public_r.type(ring_dtype))

    # x > r holds only for entries (0, 0) and (1, 1).
    expected = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool)
    assert (result.reconstruct(decode=False) == expected).all()
def test_eq():
    """Check ABY3 equality against equal and unequal protocol objects."""
    aby = ABY3()
    falcon1 = Falcon(security_type="malicious")
    falcon2 = Falcon()
    other2 = aby

    # Test equal protocol (same object):
    assert aby == other2

    # Comparing an object with its own alias passes even with identity-based
    # equality, so also compare against an independently constructed ABY3
    # that has the same security type.
    assert aby == ABY3()

    # Same protocol class, different security type
    assert aby != ABY3(security_type="malicious")

    # Different protocol class and different security type
    assert aby != falcon1

    # Different protocol class, same security type
    assert aby != falcon2
def select_shares(x: MPCTensor, y: MPCTensor, b: MPCTensor) -> MPCTensor:
    """Obliviously select x or y with a secret-shared selector bit b.

    A random bit c is shared in ring 2 and also lifted into the session ring.
    The parties open mask = b XOR c and rebuild b in the session ring:
    if mask == 0 then b == c, and if mask == 1 then b == 1 - c. The result
    is x + b * (y - x).

    Args:
        x (MPCTensor): value returned when b == 0.
        y (MPCTensor): value returned when b == 1.
        b (MPCTensor): shares of the selector bit; must live in ring size 2.

    Returns:
        z (MPCTensor): x (if b == 0) or y (if b == 1).

    Raises:
        ValueError: If the selector bit tensor is not of ring size "2".
        ValueError: If the selector bit tensor has no shape.
    """
    ring_size = int(b.share_ptrs[0].get_ring_size().get_copy())
    bit_shape = b.shape

    if ring_size != 2:
        raise ValueError(
            f"Invalid {ring_size} for selector bit,must be of ring size 2")

    if bit_shape is None:
        raise ValueError(
            "The selector bit tensor must have a valid shape.")

    session = x.session

    # TODO: Should be made to generate with CryptoProvider in Preprocessing stage.
    rand_ptrs: List[ReplicatedSharedTensor] = [
        session_ptr.prrs_generate_random_share(
            shape=bit_shape, ring_size=str(ring_size))
        for session_ptr in session.session_ptrs
    ]
    rand_bit = MPCTensor(shares=rand_ptrs, session=session, shape=bit_shape)
    # Same random bit, re-shared in the session ring.
    rand_bit_ring = ABY3.bit_injection(rand_bit, session, session.ring_size)

    ring_dtype = get_type_from_ring(session.ring_size)
    opened = (b ^ rand_bit).reconstruct(decode=False).type(ring_dtype)

    # Evaluates to rand_bit_ring when opened == 0 and 1 - rand_bit_ring when
    # opened == 1, i.e. b expressed in the session ring.
    selector = (opened - (rand_bit_ring * opened)) + (rand_bit_ring * (opened ^ 1))

    # Order placed carefully to prevent re-encoding,should not be changed.
    return x + (selector * (y - x))
def mul_master(
    x: MPCTensor,
    y: MPCTensor,
    session: Session,
    op_str: str,
    kwargs_: Dict[Any, Any],
) -> MPCTensor:
    """Dispatch a Falcon multiplication-like op and truncate its result.

    The ring size and config are read from the first share of x. The
    multiplication routine is chosen from the session protocol's
    security_type, and the product is then passed through ABY3.truncate.

    Args:
        x (MPCTensor): Secret
        y (MPCTensor): Another secret
        session (Session): Session the tensors belong to
        op_str (str): Operation string.
        kwargs_ (Dict[Any, Any]): Kwargs for some operations like conv2d

    Returns:
        result (MPCTensor): Truncated result of the operation.

    Raises:
        ValueError: Raised when number of parties are not three.
        ValueError: Raised when invalid security_type is provided.
    """
    if len(session.parties) != 3:
        raise ValueError("Falcon requires 3 parties")

    first_share = x.share_ptrs[0]
    ring_size = int(first_share.get_ring_size().get_copy())
    config = Config(**first_share.get_config().get_copy())

    security_type = session.protocol.security_type
    if security_type == "semi-honest":
        mul_fn = Falcon.mul_semi_honest
    elif security_type == "malicious":
        mul_fn = Falcon.mul_malicious
    else:
        raise ValueError("Invalid security_type for Falcon multiplication")

    product = mul_fn(x, y, session, op_str, ring_size, config, **kwargs_)
    return ABY3.truncate(product, session, ring_size, config)
def test_truncation_algorithm1(get_clients, base, precision) -> None:
    """ABY3.truncate matches floor-dividing the raw value by the encoder scale."""
    session = Session(
        parties=get_clients(3),
        protocol=Falcon("semi-honest"),
        config=Config(encoder_base=base, encoder_precision=precision),
    )
    SessionManager.setup_mpc(session)

    secret = torch.tensor([[1.24, 4.51, 6.87], [7.87, 1301, 541]])
    x_mpc = MPCTensor(secret=secret, session=session)
    truncated = ABY3.truncate(x_mpc, session, session.ring_size, session.config)

    encoder = FixedPointEncoder(
        base=session.config.encoder_base,
        precision=session.config.encoder_precision,
    )
    # Reference: floor-divide the raw (still encoded) value by the scale, then decode.
    raw = x_mpc.reconstruct(decode=False)
    expected = encoder.decode(raw // encoder.scale)

    assert np.allclose(truncated.reconstruct(), expected, atol=1e-3)
def test_bit_decomposition_ttp(get_clients, security_type) -> None:
    """OR-ing the reconstructed bit shares back together yields the raw value."""
    session = Session(
        parties=get_clients(3), protocol=Falcon(security_type=security_type)
    )
    SessionManager.setup_mpc(session)

    secret = torch.tensor([[-1, 12], [-32, 45], [98, -5624]])
    x = MPCTensor(secret=secret, session=session)
    bit_shares = ABY3.bit_decomposition_ttp(x, session)

    dtype = x.session.tensor_type
    recombined = torch.zeros(size=x.shape, dtype=dtype)
    for position in range(get_nr_bits(x.session.ring_size)):
        bit = bit_shares[position].reconstruct(decode=False).type(dtype)
        recombined |= bit << position

    # Each expected entry is the secret times 2**16 (presumably the default
    # fixed-point scale).
    expected = torch.tensor(
        [[-65536, 786432], [-2097152, 2949120], [6422528, -368574464]]
    )
    assert (recombined == expected).all()
def test_bit_injection_prime(get_clients, security_type) -> None:
    """bit_injection moves ring-2 bit shares into the PRIME_NUMBER ring."""
    session = Session(
        parties=get_clients(3), protocol=Falcon(security_type=security_type)
    )
    SessionManager.setup_mpc(session)

    bits = torch.tensor([[1, 1], [0, 0]], dtype=torch.bool)
    # The same tensor is used for all three shares.
    ptr_lst = ReplicatedSharedTensor.distribute_shares(
        [bits] * 3, session, ring_size=2
    )
    x = MPCTensor(shares=ptr_lst, session=session, shape=bits.shape)

    lifted = ABY3.bit_injection(x, session, PRIME_NUMBER)
    lifted_ring = int(lifted.share_ptrs[0].get_ring_size().get_copy())
    result = lifted.reconstruct(decode=False)

    assert (result == bits.type(torch.uint8)).all()
    assert PRIME_NUMBER == lifted_ring
def truncate(self, input_tensor: "MPCTensor", op_str: str, is_private: bool) -> "MPCTensor":
    """Truncate the result of an operation when the operation requires it.

    Only operations listed in TRUNCATED_OPS are considered:

    * ShareTensor protocols: the result is divided by base**precision for
      public operations, and for private ones when more than two parties
      take part. For the remaining private case the division is done in
      spdz's mul_parties.
    * ReplicatedSharedTensor protocols: public operations are truncated with
      ABY3.truncate, using the ring size and config read from this tensor's
      first share.

    In every other case input_tensor is returned unchanged.

    Args:
        input_tensor (MPCTensor): Result of operation
        op_str (str): Operation name
        is_private (bool): If operation is private

    Returns:
        result (MPCTensor): Truncated result
    """
    from sympc.protocol import ABY3
    from sympc.tensor import ReplicatedSharedTensor

    if op_str not in TRUNCATED_OPS:
        return input_tensor

    session = self.session

    if (not is_private or session.nr_parties > 2) and (
        session.protocol.share_class == ShareTensor
    ):
        # For private op we do the division in the mul_parties function from spdz
        scale = session.config.encoder_base ** session.config.encoder_precision
        return input_tensor.truediv(scale)

    if not is_private and session.protocol.share_class == ReplicatedSharedTensor:
        first_share = self.share_ptrs[0]
        ring_size = int(first_share.get_ring_size().get_copy())
        config = Config(**first_share.get_config().get_copy())
        return ABY3.truncate(input_tensor, session, ring_size, config)

    return input_tensor
def test_invalid_parties_trunc(get_clients) -> None:
    """ABY3.truncate raises ValueError for a session with only two parties."""
    two_party_session = Session(parties=get_clients(2))
    with pytest.raises(ValueError):
        ABY3.truncate(None, two_party_session, 2**32, None)
def test_invalid_security_type():
    """Constructing ABY3 with an unsupported security type raises ValueError."""
    unsupported = "covert"
    with pytest.raises(ValueError):
        ABY3(security_type=unsupported)
def test_session() -> None:
    """A Session stores the ABY3 protocol object it was constructed with."""
    session = Session(protocol=ABY3("semi-honest"))
    protocol_cls = type(session.protocol)
    assert protocol_cls == ABY3
def test_local_decomposition_exception() -> None:
    """local_decomposition raises ValueError for a default-constructed share."""
    default_share = ReplicatedSharedTensor()
    with pytest.raises(ValueError):
        ABY3.local_decomposition(default_share, "2")