def test_session_ring_xor(get_clients, security, bit) -> None:
    """XOR a shared tensor with a shared bit directly in the session ring.

    Both operands use a config with base 1 and precision 0, so the values are
    not scaled. The undecoded reconstruction must therefore equal the plain
    XOR of the summed shares.
    """
    clients = get_clients(3)
    session = Session(protocol=Falcon(security), parties=clients)
    SessionManager.setup_mpc(session)

    ring = session.ring_size
    dtype = session.tensor_type
    no_scale_cfg = Config(encoder_base=1, encoder_precision=0)

    x_parts = [
        torch.tensor([[927021, 3701]], dtype=dtype),
        torch.tensor([[805274, 401]], dtype=dtype),
        torch.tensor([[-1732294, -4102]], dtype=dtype),
    ]
    first_bit, second_bit, third_bit = bit
    b_parts = [
        torch.tensor([piece], dtype=dtype)
        for piece in (first_bit, second_bit, third_bit)
    ]

    x_ptrs = ReplicatedSharedTensor.distribute_shares(
        shares=x_parts, session=session, ring_size=ring, config=no_scale_cfg
    )
    b_ptrs = ReplicatedSharedTensor.distribute_shares(
        shares=b_parts, session=session, ring_size=ring, config=no_scale_cfg
    )
    x_mpc = MPCTensor(shares=x_ptrs, session=session, shape=x_parts[0].shape)
    b_mpc = MPCTensor(shares=b_ptrs, session=session, shape=b_parts[0].shape)

    plain_x = ReplicatedSharedTensor.shares_sum(x_parts, ring)
    plain_b = ReplicatedSharedTensor.shares_sum(b_parts, ring)

    xored = operator.xor(x_mpc, b_mpc)
    assert (xored.reconstruct(decode=False) == (plain_x ^ plain_b)).all()
def test_bin_xor(get_clients, bit, security) -> None:
    """XOR a binary-ring MPCTensor with a shared boolean bit."""
    clients = get_clients(3)
    session = Session(protocol=Falcon(security), parties=clients)
    session.ring_size = 2
    SessionManager.setup_mpc(session)
    ring_size = 2

    def _replicate(plain):
        # Use the same tensor for all three shares, distribute them, and
        # wrap the resulting pointers in an MPCTensor with the plain shape.
        parts = [plain, plain, plain]
        ptrs = ReplicatedSharedTensor.distribute_shares(
            shares=parts, session=session, ring_size=ring_size
        )
        mpc = MPCTensor(shares=ptrs, session=session)
        mpc.shape = plain.shape
        return parts, mpc

    x_parts, x_mpc = _replicate(
        torch.tensor([[0, 1, 0], [1, 0, 1]], dtype=torch.bool)
    )
    b_parts, b_mpc = _replicate(torch.tensor([bit], dtype=torch.bool))

    plain_x = ReplicatedSharedTensor.shares_sum(x_parts, ring_size)
    plain_b = ReplicatedSharedTensor.shares_sum(b_parts, ring_size)

    xored = operator.xor(x_mpc, b_mpc)
    assert (xored.reconstruct(decode=False) == (plain_x ^ plain_b)).all()
def test_prime_mul_private(get_clients, security):
    """Multiply two private prime-ring MPCTensors and compare to the plain op."""
    clients = get_clients(3)
    session = Session(protocol=Falcon(security), parties=clients)
    SessionManager.setup_mpc(session)

    ring_size = PRIME_NUMBER
    prime_mul = ReplicatedSharedTensor.get_op(ring_size, "mul")

    lhs = torch.tensor([[32, 12, 23], [17, 35, 7]], dtype=torch.uint8)
    rhs = torch.tensor([[45, 66, 47], [19, 57, 2]], dtype=torch.uint8)
    lhs_parts = [lhs] * 3
    rhs_parts = [rhs] * 3

    lhs_ptrs, rhs_ptrs = (
        ReplicatedSharedTensor.distribute_shares(
            shares=parts, session=session, ring_size=ring_size
        )
        for parts in (lhs_parts, rhs_parts)
    )

    lhs_mpc = MPCTensor(shares=lhs_ptrs, session=session)
    lhs_mpc.shape = lhs.shape
    rhs_mpc = MPCTensor(shares=rhs_ptrs, session=session)
    rhs_mpc.shape = rhs.shape

    plain_lhs = ReplicatedSharedTensor.shares_sum(lhs_parts, ring_size)
    plain_rhs = ReplicatedSharedTensor.shares_sum(rhs_parts, ring_size)

    product = operator.mul(lhs_mpc, rhs_mpc)
    wanted = prime_mul(plain_lhs, plain_rhs)
    assert (product.reconstruct(decode=False) == wanted).all()
def test_bin_mul_private(get_clients, security):
    """Multiply two private binary-ring MPCTensors and compare to the plain op."""
    clients = get_clients(3)
    session = Session(protocol=Falcon(security), parties=clients)
    SessionManager.setup_mpc(session)

    ring_size = 2
    bin_mul = ReplicatedSharedTensor.get_op(ring_size, "mul")

    lhs = torch.tensor([[0, 1, 0], [1, 0, 1]], dtype=torch.bool)
    rhs = torch.tensor([[1, 1, 0], [0, 1, 1]], dtype=torch.bool)
    lhs_parts = [lhs] * 3
    rhs_parts = [rhs] * 3

    lhs_ptrs, rhs_ptrs = (
        ReplicatedSharedTensor.distribute_shares(
            shares=parts, session=session, ring_size=ring_size
        )
        for parts in (lhs_parts, rhs_parts)
    )

    lhs_mpc = MPCTensor(shares=lhs_ptrs, session=session)
    lhs_mpc.shape = lhs.shape
    rhs_mpc = MPCTensor(shares=rhs_ptrs, session=session)
    rhs_mpc.shape = rhs.shape

    plain_lhs = ReplicatedSharedTensor.shares_sum(lhs_parts, ring_size)
    plain_rhs = ReplicatedSharedTensor.shares_sum(rhs_parts, ring_size)

    product = operator.mul(lhs_mpc, rhs_mpc)
    wanted = bin_mul(plain_lhs, plain_rhs)
    assert (product.reconstruct(decode=False) == wanted).all()
def test_ops_prime_public_xor(get_clients, security, bit) -> None:
    """XOR a prime-ring MPCTensor with a public uint8 bit tensor."""
    clients = get_clients(3)
    session = Session(protocol=Falcon(security), parties=clients)
    SessionManager.setup_mpc(session)

    ring_size = PRIME_NUMBER
    parts = [
        torch.tensor([[17, 44], [8, 20]], dtype=torch.uint8),
        torch.tensor([[8, 51], [27, 52]], dtype=torch.uint8),
        torch.tensor([[42, 40], [32, 63]], dtype=torch.uint8),
    ]
    ptrs = ReplicatedSharedTensor.distribute_shares(
        shares=parts, session=session, ring_size=ring_size
    )
    shared = MPCTensor(shares=ptrs, session=session)
    shared.shape = parts[0].shape

    plain = ReplicatedSharedTensor.shares_sum(parts, ring_size)
    public_bit = torch.tensor([bit], dtype=torch.uint8)

    xored = operator.xor(shared, public_bit)
    assert (xored.reconstruct(decode=False) == (plain ^ public_bit)).all()
def local_decomposition(
        x: ReplicatedSharedTensor,
        ring_size: str,
        bitwise: bool = False) -> List[List[ReplicatedSharedTensor]]:
    """Performs local decomposition to generate shares of shares.

    Each input RSTensor is decomposed into ``NR_PARTIES`` RSTensors. The
    input is ``x`` itself, or every bit extracted from ``x`` when ``bitwise``
    is set. The calling party places its own first share component in slot 0
    of the RSTensor at index ``rank``. It places its second share component
    in slot 1 of the RSTensor at index ``(rank + 1) % NR_PARTIES``. Every
    other slot is zero.

    Args:
        x (ReplicatedSharedTensor) : input RSTensor.
        ring_size (str) : Ring size to generate decomposed shares in. It is
            converted with ``int()`` before use.
        bitwise (bool): Perform bit level decomposition on bits if set. The
            number of extracted bits is taken from the *session* ring size,
            not from ``ring_size``.

    Returns:
        List[List[ReplicatedSharedTensor]]: One inner list per decomposed
        input. Each inner list holds ``NR_PARTIES`` RSTensors, each a clone
        of ``x`` with its shares replaced and its ``ring_size`` set to
        ``int(ring_size)``.

    Raises:
        ValueError: If RSTensor does not have session uuid.
        ValueError: If exactly three parties are not involved in the computation.
    """
    if x.session_uuid is None:
        raise ValueError("Input RSTensor should have session_uuid")

    session = get_session(x.session_uuid)

    if session.nr_parties != NR_PARTIES:
        raise ValueError(
            "ABY3 local_decomposition algorithm requires 3 parties")

    ring_size = int(ring_size)
    tensor_type = get_type_from_ring(ring_size)

    rank = session.rank
    # Zero template with the same shape as the local shares, in the target ring's dtype.
    zero = torch.zeros(x.shares[0].shape).type(tensor_type)

    # Similar to triples, we have instances for the shares generated.
    share_lst: List[List[ReplicatedSharedTensor]] = []

    input_rst = []
    if bitwise:
        ring_bits = get_nr_bits(
            session.ring_size)  # for bit-wise decomposition
        input_rst = [x.bit_extraction(idx) for idx in range(ring_bits)]
    else:
        input_rst.append(x)

    for share in input_rst:
        # shares[i] holds the two share slots of the RSTensor built for index i.
        shares = [[zero.clone(), zero.clone()] for i in range(NR_PARTIES)]
        # Only the entries this party knows locally are filled in; the rest stay zero.
        shares[rank][0] = share.shares[0].clone().type(tensor_type)
        shares[(rank + 1) % NR_PARTIES][1] = (share.shares[1].clone().type(tensor_type))

        rst_sh = []
        for i in range(NR_PARTIES):
            # Clone x so session_uuid/config carry over; only shares and ring_size change.
            rst = x.clone()
            rst.shares = shares[i]
            rst.ring_size = ring_size
            rst_sh.append(rst)

        share_lst.append(rst_sh)

    return share_lst
def test_different_session_ids() -> None:
    """RSTensors with equal shares but distinct session ids are unequal."""
    value = torch.tensor([1])
    pair = [value, value]
    first, second = (
        ReplicatedSharedTensor(shares=pair, session_uuid=uuid4())
        for _ in range(2)
    )
    # Only the session id differs between the two tensors.
    assert first != second
def test_different_ring_size() -> None:
    """RSTensors with equal shares but distinct ring sizes are unequal."""
    value = torch.tensor([1])
    pair = [value, value]
    first, second = (
        ReplicatedSharedTensor(shares=pair, ring_size=size)
        for size in (2**32, 2**64)
    )
    # Only the ring size differs between the two tensors.
    assert first != second
def test_different_shares() -> None:
    """RSTensors in the same session with different shares are unequal."""
    ones = torch.tensor([1])
    twos = torch.tensor([2])
    shared_uuid = uuid4()
    first, second = (
        ReplicatedSharedTensor(shares=[t, t], session_uuid=shared_uuid)
        for t in (ones, twos)
    )
    # Only the share contents differ between the two tensors.
    assert first != second
def test_same_session_id_and_data() -> None:
    """RSTensors built from the same shares and session id compare equal.

    This is the positive counterpart to the ``test_different_*`` checks.
    Session id, shares, ring size and config all match, so ``==`` must hold.
    """
    x = torch.randn(1)
    shares = [x, x]
    session_id = uuid4()

    x_share = ReplicatedSharedTensor(shares=shares, session_uuid=session_id)
    y_share = ReplicatedSharedTensor(shares=shares, session_uuid=session_id)

    # Same session id and same share data: the tensors must be equal.
    assert x_share == y_share
def test_hook_property(get_clients) -> None:
    """The hooked ``.T`` property transposes every local share."""
    clients = get_clients(3)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    x = torch.randn(1, 3)
    y = torch.randn(1, 3)

    rst = ReplicatedSharedTensor()
    rst.shares = [x, y]

    for idx, plain in enumerate((x, y)):
        assert (rst.T.shares[idx] == plain.T).all()
def test_ops_prime_share_public(op_str) -> None:
    """A prime-ring RSTensor op with a public tensor matches the ring op."""
    python_op = getattr(operator, op_str)
    ring_size = PRIME_NUMBER
    ring_op = ReplicatedSharedTensor.get_op(ring_size, op_str)

    lhs = torch.tensor([[24, 34], [66, 1]], dtype=torch.uint8)
    rhs = torch.tensor([[34, 47], [45, 32]], dtype=torch.uint8)
    shared_lhs = ReplicatedSharedTensor(shares=[lhs], ring_size=ring_size)

    wanted = ring_op(lhs, rhs)
    got = python_op(shared_lhs, rhs).shares[0]
    assert (got == wanted).all()
def test_ops_bin_share_public(op_str) -> None:
    """A binary-ring RSTensor op with a public tensor matches the ring op."""
    python_op = getattr(operator, op_str)
    ring_size = 2
    ring_op = ReplicatedSharedTensor.get_op(ring_size, op_str)

    lhs = torch.tensor([[0, 1], [1, 1]], dtype=torch.bool)
    rhs = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool)
    shared_lhs = ReplicatedSharedTensor(shares=[lhs], ring_size=ring_size)

    wanted = ring_op(lhs, rhs)
    got = python_op(shared_lhs, rhs).shares[0]
    assert (got == wanted).all()
def test_send_get(get_clients, precision=12, base=4) -> None:
    """Sending an RSTensor to a client and fetching it back round-trips."""
    client = get_clients(1)[0]
    session = Session(protocol=Falcon("semi-honest"), parties=[client])
    SessionManager.setup_mpc(session)

    fractional = [1.4, 2.34, 3.43]
    payload = [
        torch.Tensor(fractional),
        torch.Tensor([1, 2, 3]),
        torch.Tensor(fractional),
    ]
    original = ReplicatedSharedTensor(
        shares=payload, session_uuid=session.rank_to_uuid[0]
    )

    fetched = original.send(client).get()
    assert fetched == original
def test_different_config() -> None:
    """RSTensors that differ only in fixed-point config are unequal."""
    value = torch.tensor([1])
    pair = [value, value]
    shared_uuid = uuid4()
    configs = (
        Config(encoder_precision=10, encoder_base=2),
        Config(encoder_precision=12, encoder_base=10),
    )
    first, second = (
        ReplicatedSharedTensor(shares=pair, session_uuid=shared_uuid, config=cfg)
        for cfg in configs
    )
    # Only the fixed point config differs between the two tensors.
    assert first != second
def truncation_algorithm1(
    ptr_list: List[torch.Tensor],
    shape: torch.Size,
    session: Session,
    ring_size: int,
    config: Config,
) -> List[ReplicatedSharedTensor]:
    """Performs the ABY3 truncation algorithm1.

    The first tensor is truncated on its own and the sum of the second and
    third is truncated together. A fresh random mask is then subtracted from
    the combined part and redistributed as the third share.

    Args:
        ptr_list (List[torch.Tensor]): The three tensors to truncate.
        shape (torch.Size) : Shape of the random mask (tensor values).
        session (Session) : Session the tensors belong to.
        ring_size (int): Ring size of the underlying tensors.
        config (Config): The configuration(base,precision) of the underlying tensors.

    Returns:
        List["ReplicatedSharedTensor"] : Truncated shares, as returned by
        ``ReplicatedSharedTensor.distribute_shares``.
    """
    dtype = get_type_from_ring(ring_size)
    mask = torch.empty(size=shape, dtype=dtype).random_(generator=gen)

    base, precision = config.encoder_base, config.encoder_precision
    divisor = base**precision

    def _shrink(value: torch.Tensor) -> torch.Tensor:
        # Base 2 can use a right bit shift; other bases use floor division.
        if base == 2:
            return value >> precision
        return value // divisor

    first, second, third = ptr_list
    head = _shrink(first)
    tail = _shrink(second + third)

    return ReplicatedSharedTensor.distribute_shares(
        [head, tail - mask, mask], session, ring_size, config)
def compute_zvalue_and_add_mask(
    x: ReplicatedSharedTensor,
    y: ReplicatedSharedTensor,
    op_str: str,
    **kwargs: Dict[Any, Any],
) -> torch.Tensor:
    """Compute this party's local z share and hide it behind a PRZS mask.

    Args:
        x (ReplicatedSharedTensor): Secret.
        y (ReplicatedSharedTensor): Another secret.
        op_str (str): Operation string.
        kwargs (Dict[Any, Any]): Kwargs for some operations like conv2d

    Returns:
        share (Torch.tensor): The local z value plus the first component of
        a pseudo-random zero share of the output shape.
    """
    session = get_session(x.session_uuid)

    # Each party computes its z value locally.
    local_z = Falcon.multiplication_protocol(x, y, op_str, **kwargs)

    out_shape = MPCTensor._get_shape(op_str, x.shape, y.shape)
    zero_share = session.przs_generate_random_share(
        shape=out_shape, ring_size=str(x.ring_size))

    ring_add = ReplicatedSharedTensor.get_op(x.ring_size, "add")
    return ring_add(local_z, zero_share.get_shares()[0])
def multiplication_protocol(
    x: ReplicatedSharedTensor,
    y: ReplicatedSharedTensor,
    op_str: str,
    **kwargs: Dict[Any, Any],
) -> ReplicatedSharedTensor:
    """Implementation of Falcon's multiplication with semi-honest security guarantee.

    Computes ``op(x0, y0) + op(x1, y0) + op(x0, y1)`` over the ring of ``x``,
    where ``x0``/``x1`` and ``y0``/``y1`` are the two locally held share
    components.

    Args:
        x (ReplicatedSharedTensor): Secret
        y (ReplicatedSharedTensor): Another secret
        op_str (str): Operator string.
        kwargs (Dict[Any, Any]): Keywords arguments for the operator.

    Returns:
        shares (ReplicatedSharedTensor): The value returned by ``shares_sum``
        over the three cross terms.
    """
    ring_op = ReplicatedSharedTensor.get_op(x.ring_size, op_str)

    x_first, x_second = x.shares[0], x.shares[1]
    y_first, y_second = y.shares[0], y.shares[1]

    cross_terms = [
        ring_op(x_first, y_first, **kwargs),
        ring_op(x_second, y_first, **kwargs),
        ring_op(x_first, y_second, **kwargs),
    ]
    return shares_sum(cross_terms, x.ring_size)
def prrs_generate_random_share(
    self,
    shape: Union[tuple, torch.Size],
) -> Any:
    """Build a random share from the two generators this party holds.

    Args:
        shape (Union[tuple, torch.Size]): Shape for the share.

    Returns:
        Any: A ShareTensor wrapping the first random tensor when the
        protocol's share class is ShareTensor; otherwise a
        ReplicatedSharedTensor holding both random tensors.
    """
    from sympc.tensor import ReplicatedSharedTensor
    from sympc.tensor import ShareTensor

    first_rand, second_rand = self._generate_random_share(shape)

    # encoder_precision = 0 so the random values are not fixed-point encoded.
    unencoded = Config(encoder_precision=0)

    if self.protocol.share_class != ShareTensor:
        return ReplicatedSharedTensor(
            shares=[first_rand, second_rand],
            session_uuid=self.uuid,
            config=unencoded,
        )

    return ShareTensor(
        data=first_rand,
        session_uuid=self.uuid,
        config=unencoded,
    )
def test_fixed_point(precision, base) -> None:
    """RSTensor encodes and decodes shares exactly like FixedPointEncoder."""
    x = torch.tensor([1.25, 3.301])
    shares = [x, x]
    rst = ReplicatedSharedTensor(
        shares=shares,
        config=Config(encoder_precision=precision, encoder_base=base),
    )

    encoder = FixedPointEncoder(precision=precision, base=base)
    dtype = get_type_from_ring(rst.ring_size)

    # Update the list in place (slice assignment) to mirror the original mutation.
    shares[:] = [encoder.encode(part).to(dtype) for part in shares]
    assert (torch.cat(shares) == torch.cat(rst.shares)).all()

    shares[:] = [encoder.decode(part.type(torch.LongTensor)) for part in shares]
    assert (torch.cat(shares) == torch.cat(rst.decode())).all()
def triple_verification(
    z_sh: ReplicatedSharedTensor,
    eps: torch.Tensor,
    delta: torch.Tensor,
    op_str: str,
    **kwargs: Dict[Any, Any],
) -> ReplicatedSharedTensor:
    """Performs Beaver's triple verification check.

    Fetches a ``beaver_{op_str}`` triple (a, b, c) from the session crypto
    store. It then recomputes a sharing of
    ``c + op(a, delta) + op(eps, b) + op(eps, delta)`` and subtracts it from
    ``z_sh``. The operator kwargs are applied to every ``op`` call, so
    shape-affecting ops (e.g. conv2d stride/padding) are evaluated
    consistently.

    Args:
        z_sh (ReplicatedSharedTensor) : share of multiplied value(x*y).
        eps (torch.Tensor) :masked value of x
        delta (torch.Tensor): masked value of y
        op_str (str): Operator string.
        kwargs (Dict[Any, Any]): Keywords arguments for the operator.

    Returns:
        ReplicatedSharedTensor : ``z_sh`` minus the recomputed sharing.
    """
    session = get_session(z_sh.session_uuid)
    ring_size = z_sh.ring_size
    crypto_store = session.crypto_store
    eps_shape = tuple(eps.shape)
    delta_shape = tuple(delta.shape)

    primitives = crypto_store.get_primitives_from_store(
        f"beaver_{op_str}", eps_shape, delta_shape)
    a_share, b_share, c_share = primitives

    op = ReplicatedSharedTensor.get_op(ring_size, op_str)

    eps_delta = op(eps, delta, **kwargs)

    eps_b = b_share.clone()
    delta_a = a_share.clone()

    # Operate on the raw share tensors to prevent re-encoding, as the values
    # are already encoded.
    # TODO: should be improved.
    # kwargs must be forwarded here too: eps_delta above uses them, and the
    # cross terms must be computed with the same operator settings.
    for i in range(2):
        eps_b.shares[i] = op(eps, eps_b.shares[i], **kwargs)
        delta_a.shares[i] = op(delta_a.shares[i], delta, **kwargs)

    rst_share = c_share + delta_a + eps_b

    # The public term eps_delta is added into exactly one share component.
    # This assumes party i holds components (i, i+1), so rank 0 slot 0 and
    # rank 2 slot 1 are the same component. TODO confirm against the sharing
    # layout.
    if session.rank == 0:
        rst_share.shares[0] = shares_sum([rst_share.shares[0], eps_delta], ring_size)

    if session.rank == 2:
        rst_share.shares[1] = shares_sum([rst_share.shares[1], eps_delta], ring_size)

    result = z_sh - rst_share

    return result
def truncate(x: MPCTensor, session: Session, ring_size: int, config: Config) -> MPCTensor:
    """Performs the ABY3 truncation algorithm.

    For rings 2 and PRIME_NUMBER the shares are only redistributed (or
    reused as-is) and not truncated. For any other ring, ABY3 truncation
    algorithm 1 is applied.

    Args:
        x (MPCTensor): input tensor
        session (Session) : session of the input tensor.
        ring_size (int): Ring size of the underlying tensor.
        config (Config) : The configuration(base,precision) of the underlying tensor.

    Returns:
        MPCTensor: truncated MPCTensor.

    Raises:
        ValueError : parties involved in the computation is not equal to three.
        ValueError : Invalid MPCTensor share pointers.

    TODO :Switch to trunc2 algorithm as it is communication efficient.
    """
    if session.nr_parties != NR_PARTIES:
        raise ValueError(
            "Share truncation algorithm 1 works only for 3 parites.")

    # RSPointer - public ops, Tensor Pointer - Private ops
    ptr_list = []
    ptr_name = x.share_ptrs[0].__name__

    # TODO: Should be made more concise; a lot of branching is done here. Improve communication efficiency.
    if ptr_name == "ReplicatedSharedTensorPointer":
        if ring_size in {2, PRIME_NUMBER}:
            share_ptrs = x.share_ptrs
        else:
            # Party 0's first component plus both of party 1's components
            # give the three tensors truncation_algorithm1 expects.
            ptr_list.append(x.share_ptrs[0].get_shares()[0].get_copy())
            ptr_list.extend(x.share_ptrs[1].get_copy().shares)
            share_ptrs = ABY3.truncation_algorithm1(
                ptr_list, x.shape, session, ring_size, config)

    elif ptr_name == "TensorPointer":
        ptr_list = [share.get_copy() for share in x.share_ptrs]

        if ring_size in {2, PRIME_NUMBER}:
            share_ptrs = ReplicatedSharedTensor.distribute_shares(
                ptr_list, session, ring_size, config)
        else:
            share_ptrs = ABY3.truncation_algorithm1(
                ptr_list, x.shape, session, ring_size, config)
    else:
        # f-string so the offending pointer type is actually reported.
        raise ValueError(f"{ptr_name} not supported.")

    result = MPCTensor(shares=share_ptrs, session=session, shape=x.shape)

    return result
def przs_generate_random_share(
    self,
    shape: Union[tuple, torch.Size],
    ring_size: Optional[str] = None,
) -> Any:
    """Build a pseudo-random zero share from this party's two generators.

    The share is ``current - next`` of the two generated random tensors,
    computed in the requested ring.

    Args:
        shape (Union[tuple, torch.Size]): Shape for the share.
        ring_size (str): ring size to generate share. It is passed as a
            string because 2**64 cannot be serialized. When omitted, the
            session's ``ring_size`` is used.

    Returns:
        Any: ShareTensor or ReplicatedSharedTensor, depending on the
        protocol's share class.
    """
    from sympc.tensor import ReplicatedSharedTensor
    from sympc.tensor import ShareTensor

    ring = self.ring_size if ring_size is None else int(ring_size)
    mine, neighbour = self._generate_random_share(shape, ring)

    # encoder_precision = 0 so the value is not fixed-point encoded.
    unencoded = Config(encoder_precision=0)

    if self.protocol.share_class == ShareTensor:
        return ShareTensor(
            data=mine - neighbour,
            session_uuid=self.uuid,
            config=unencoded,
            ring_size=ring,
        )

    ring_sub = ReplicatedSharedTensor.get_op(ring, "sub")
    return ReplicatedSharedTensor(
        shares=[ring_sub(mine, neighbour)],
        session_uuid=self.uuid,
        config=unencoded,
        ring_size=ring,
    )
def test_ops_share_private(op_str, precision, base) -> None:
    """Private RSTensor ops agree with plain tensor ops after decoding."""
    python_op = getattr(operator, op_str)

    lhs = torch.Tensor([[0.125, -1.25], [-4.25, 4]])
    rhs = torch.Tensor([[4.5, -2.5], [5, 2.25]])

    shared_lhs, shared_rhs = (
        ReplicatedSharedTensor(
            shares=[plain],
            config=Config(encoder_base=base, encoder_precision=precision),
        )
        for plain in (lhs, rhs)
    )

    wanted = python_op(lhs, rhs)
    combined = python_op(shared_lhs, shared_rhs)
    decoded = combined.fp_encoder.decode(combined.shares[0])

    # Tolerance equals one unit of fixed-point resolution.
    assert np.allclose(decoded, wanted, rtol=base**-precision)
def test_share_distribution_number_shares(get_clients, parties):
    """Each party receives an RSTensor holding ``n_parties - 1`` shares."""
    clients = get_clients(parties)
    session = Session(protocol=Falcon("semi-honest"), parties=clients)
    SessionManager.setup_mpc(session)

    party_count = len(clients)
    shares = MPCTensor.generate_shares(100.42, party_count)
    share_ptrs = ReplicatedSharedTensor.distribute_shares(shares, session)

    for rst_ptr in share_ptrs:
        assert len(rst_ptr.get_shares().get()) == party_count - 1
def test_prime_xor(get_clients, security, bit) -> None:
    """XOR a prime-ring MPCTensor with a shared uint8 bit."""
    clients = get_clients(3)
    session = Session(protocol=Falcon(security), parties=clients)
    session.ring_size = PRIME_NUMBER
    SessionManager.setup_mpc(session)
    ring_size = PRIME_NUMBER

    x_parts = [
        torch.tensor([[17, 44], [8, 20]], dtype=torch.uint8),
        torch.tensor([[8, 51], [27, 52]], dtype=torch.uint8),
        torch.tensor([[42, 40], [32, 63]], dtype=torch.uint8),
    ]
    first_bit, second_bit, third_bit = bit
    b_parts = [
        torch.tensor([piece], dtype=torch.uint8)
        for piece in (first_bit, second_bit, third_bit)
    ]

    x_ptrs = ReplicatedSharedTensor.distribute_shares(
        shares=x_parts, session=session, ring_size=ring_size)
    b_ptrs = ReplicatedSharedTensor.distribute_shares(
        shares=b_parts, session=session, ring_size=ring_size)

    x_mpc = MPCTensor(shares=x_ptrs, session=session)
    b_mpc = MPCTensor(shares=b_ptrs, session=session)
    x_mpc.shape = x_parts[0].shape
    b_mpc.shape = b_parts[0].shape

    plain_x = ReplicatedSharedTensor.shares_sum(x_parts, ring_size)
    plain_b = ReplicatedSharedTensor.shares_sum(b_parts, ring_size)

    xored = operator.xor(x_mpc, b_mpc)
    assert (xored.reconstruct(decode=False) == (plain_x ^ plain_b)).all()
def proto2object(proto: ReplicatedSharedTensor_PB) -> ReplicatedSharedTensor:
    """Rebuild a ReplicatedSharedTensor from its protobuf message.

    When the message carries a session uuid, the config is taken from that
    registered session. Otherwise it is deserialized from the message.

    Raises:
        ValueError: If the referenced session cannot be found.
    """
    session_uuid = proto.session_uuid

    if not session_uuid:
        config = syft.deserialize(proto.config, from_proto=True)
    else:
        session = sympc.session.get_session(session_uuid)
        if session is None:
            raise ValueError(f"The session {proto.session_uuid} could not be found")
        config = dataclasses.asdict(session.config)

    decoded_shares = [protobuf_tensor_deserializer(t) for t in proto.tensor]

    rst = ReplicatedSharedTensor(shares=None, config=Config(**config))
    if session_uuid:
        rst.session_uuid = UUID(session_uuid)
    rst.shares = decoded_shares

    return rst
def test_ops_bin_public_xor(get_clients, security, bit) -> None:
    """XOR a binary-ring MPCTensor with a public boolean bit tensor."""
    clients = get_clients(3)
    session = Session(protocol=Falcon(security), parties=clients)
    SessionManager.setup_mpc(session)

    ring_size = 2
    plain_share = torch.tensor([[0, 1, 0], [1, 0, 1]], dtype=torch.bool)
    parts = [plain_share] * 3

    ptrs = ReplicatedSharedTensor.distribute_shares(
        shares=parts, session=session, ring_size=ring_size)
    shared = MPCTensor(shares=ptrs, session=session)
    shared.shape = plain_share.shape

    plain = ReplicatedSharedTensor.shares_sum(parts, ring_size)
    public_bit = torch.tensor([bit], dtype=torch.bool)

    xored = operator.xor(shared, public_bit)
    assert (xored.reconstruct(decode=False) == (plain ^ public_bit)).all()
def test_ops_prime_public_mul(get_clients, security) -> None:
    """Multiply a prime-ring MPCTensor by a public uint8 tensor."""
    clients = get_clients(3)
    session = Session(protocol=Falcon(security), parties=clients)
    SessionManager.setup_mpc(session)

    ring_size = PRIME_NUMBER
    prime_mul = ReplicatedSharedTensor.get_op(ring_size, "mul")

    plain_share = torch.tensor([[33, 45, 0], [52, 41, 22]], dtype=torch.uint8)
    parts = [plain_share] * 3

    ptrs = ReplicatedSharedTensor.distribute_shares(
        shares=parts, session=session, ring_size=ring_size)
    shared = MPCTensor(shares=ptrs, session=session)
    shared.shape = plain_share.shape

    plain = ReplicatedSharedTensor.shares_sum(parts, ring_size)
    public_value = torch.tensor([66], dtype=torch.uint8)

    product = operator.mul(shared, public_value)
    wanted = prime_mul(plain, public_value)
    assert (product.reconstruct(decode=False) == wanted).all()
def mul_semi_honest(
    x: MPCTensor,
    y: MPCTensor,
    session: Session,
    op_str: str,
    ring_size: int,
    config: Config,
    reshare: bool = False,
    **kwargs_: Dict[Any, Any],
) -> MPCTensor:
    """Falcon semihonest multiplication.

    Each party runs ``Falcon.compute_zvalue_and_add_mask`` on its pair of
    shares in parallel. When ``reshare`` is set, the resulting 3-out-of-3
    shares are fetched and redistributed as 2-out-of-3 replicated shares.

    Args:
        x (MPCTensor): Secret
        y (MPCTensor): Another secret
        session (Session): Session the tensors belong to
        op_str (str): Operation string.
        ring_size (int) : Ring size of the underlying tensors.
        config (Config): The configuration(base,precision) of the underlying tensor.
        reshare (bool) : Convert 3-out-3 to 2-out-3 if set.
        kwargs_ (Dict[Any, Any]): Kwargs for some operations like conv2d

    Returns:
        MPCTensor: Result of the operation, with its shape set from
        ``MPCTensor._get_shape``.
    """
    per_party_args = [
        [x_ptr, y_ptr, op_str] for x_ptr, y_ptr in zip(x.share_ptrs, y.share_ptrs)
    ]
    run_parallel = parallel_execution(
        Falcon.compute_zvalue_and_add_mask, session.parties)
    masked_z_ptrs = run_parallel(per_party_args, kwargs_)

    result = MPCTensor(shares=masked_z_ptrs, session=x.session)

    if reshare:
        # Convert 3-3 shares to 2-3 shares by resharing.
        local_z = [ptr.get() for ptr in masked_z_ptrs]
        replicated_ptrs = ReplicatedSharedTensor.distribute_shares(
            local_z, x.session, ring_size, config)
        result = MPCTensor(shares=replicated_ptrs, session=x.session)

    result.shape = MPCTensor._get_shape(op_str, x.shape, y.shape)  # for prrs
    return result