def __init__(
    self,
    shares: Optional[List[Union[float, int, torch.Tensor]]] = None,
    config: Config = Config(encoder_base=2, encoder_precision=16),
    session_uuid: Optional[UUID] = None,
    ring_size: int = 2**64,
):
    """Initialize ReplicatedSharedTensor.

    Every provided share is passed through the fixed-point encoder and then
    cast to the integer dtype associated with ``ring_size``.

    Args:
        shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list
            from which the RSTensor is created. ``None`` yields no shares.
        config (Config): The configuration where we keep the encoder
            precision and base.
        session_uuid (Optional[UUID]): Used to keep track of a share that is
            associated with a remote session.
        ring_size (int): Field used for the operations applied on the shares.
            Defaults to 2**64.
    """
    self.session_uuid = session_uuid
    self.ring_size = ring_size
    self.config = config
    # The encoder must exist before any share is encoded below.
    self.fp_encoder = FixedPointEncoder(
        base=config.encoder_base,
        precision=config.encoder_precision,
    )

    ring_dtype = get_type_from_ring(ring_size)
    if shares is None:
        self.shares = []
    else:
        self.shares = [self._encode(value).to(ring_dtype) for value in shares]
def __init__(
    self,
    data: Optional[Union[float, int, torch.Tensor]] = None,
    config: Config = Config(encoder_base=2, encoder_precision=16),
    session_uuid: Optional[UUID] = None,
    ring_size: int = 2**64,
) -> None:
    """Initialize ShareTensor.

    When ``data`` is given it is fixed-point encoded and cast to the integer
    dtype associated with ``ring_size``; otherwise ``self.tensor`` stays None.

    Args:
        data (Optional[Any]): The share a party holds. Defaults to None.
        config (Config): The configuration where we keep the encoder
            precision and base.
        session_uuid (Optional[UUID]): Used to keep track of a share that is
            associated with a remote session.
        ring_size (int): Field used for the operations applied on the shares.
            Defaults to 2**64.
    """
    self.session_uuid = session_uuid
    self.ring_size = ring_size
    self.config = config
    self.fp_encoder = FixedPointEncoder(
        base=config.encoder_base,
        precision=config.encoder_precision,
    )

    self.tensor: Optional[torch.Tensor] = None
    if data is None:
        return

    ring_dtype = get_type_from_ring(ring_size)
    self.tensor = self._encode(data).to(ring_dtype)
def _generate_random_share(
    self, shape: Union[tuple, torch.Size], ring_size: int
) -> List[torch.Tensor]:
    """Generate random tensor share for the given shape and ring_size.

    One tensor is drawn from each of the party's two PRZS generators.

    Args:
        shape (Union[tuple, torch.Size]): Shape for the share.
        ring_size (int): ring size to generate share.

    The generators are invoked in Counter(CTR) mode as parties with the same
    initial seeds could generate correlated random numbers on subsequent
    invocations.

    Returns:
        List[torch.Tensor]: shares generated by the generators, in generator order.
    """
    from sympc.tensor import PRIME_NUMBER

    # Unpacking enforces that exactly two generators are configured.
    first_gen, second_gen = self.przs_generators
    ring_dtype = get_type_from_ring(ring_size)
    # Values are bounded explicitly only when working in the prime field.
    upper_bound = None
    if ring_size == PRIME_NUMBER:
        upper_bound = PRIME_NUMBER

    return [
        generate_random_element(
            tensor_type=ring_dtype,
            generator=generator,
            shape=shape,
            max_val=upper_bound,
        )
        for generator in (first_gen, second_gen)
    ]
def truncation_algorithm1(
    ptr_list: List[torch.Tensor],
    shape: torch.Size,
    session: Session,
    ring_size: int,
    config: Config,
) -> List[ReplicatedSharedTensor]:
    """Performs the ABY3 truncation algorithm1.

    The first tensor is truncated on its own, the second and third are summed
    and truncated together, and that sum is re-split with a fresh random mask
    before being redistributed to the parties.

    Args:
        ptr_list (List[torch.Tensor]): Tensors to truncate (exactly three).
        shape (torch.Size): shape of tensor values.
        session (Session): session the tensor belong to.
        ring_size (int): Ring size of the underlying tensors.
        config (Config): The configuration(base,precision) of the underlying tensors.

    Returns:
        List["ReplicatedSharedTensor"]: Truncated shares.
    """
    ring_dtype = get_type_from_ring(ring_size)
    # NOTE(review): `gen` is not defined in this function; presumably a
    # module-level random generator — confirm at file scope.
    mask = torch.empty(size=shape, dtype=ring_dtype).random_(generator=gen)

    base = config.encoder_base
    precision = config.encoder_precision
    scale = base**precision

    def _truncate(value):
        # Base 2 allows an arithmetic shift; other bases need floor division.
        return value >> precision if base == 2 else value // scale

    first, second, third = ptr_list
    truncated_first = _truncate(first)
    truncated_rest = _truncate(second + third)

    return ReplicatedSharedTensor.distribute_shares(
        [truncated_first, truncated_rest - mask, mask], session, ring_size, config
    )
def local_decomposition(
    x: ReplicatedSharedTensor, ring_size: str, bitwise: bool = False
) -> List[List[ReplicatedSharedTensor]]:
    """Performs local decomposition to generate shares of shares.

    For each decomposed input (``x`` itself, or each extracted bit of ``x``
    when ``bitwise`` is set), NR_PARTIES RSTensors are built. All share slots
    start at zero. The RSTensor at this party's rank receives ``shares[0]`` in
    slot 0, and the RSTensor at ``rank + 1`` receives ``shares[1]`` in slot 1.

    Args:
        x (ReplicatedSharedTensor): input RSTensor.
        ring_size (str): Ring size to generate decomposed shares in.
        bitwise (bool): Perform bit level decomposition on bits if set.

    Returns:
        List[List[ReplicatedSharedTensor]]: One inner list per decomposed input
        (a single entry, or one per bit when ``bitwise`` is set). Each inner list
        holds NR_PARTIES RSTensors in the given ring size, indexed by party rank.

    Raises:
        ValueError: If RSTensor does not have session uuid.
        ValueError: If the session does not have exactly NR_PARTIES parties
            (ABY3 requires 3).
    """
    if x.session_uuid is None:
        raise ValueError("Input RSTensor should have session_uuid")

    session = get_session(x.session_uuid)
    if session.nr_parties != NR_PARTIES:
        raise ValueError("ABY3 local_decomposition algorithm requires 3 parties")

    ring_size = int(ring_size)
    tensor_type = get_type_from_ring(ring_size)
    rank = session.rank
    zero = torch.zeros(x.shares[0].shape).type(tensor_type)

    # Similar to triples, we have instances for the shares generated.
    share_lst: List[List[ReplicatedSharedTensor]] = []

    if bitwise:
        # The bit count comes from the session's ring size, not from `ring_size`.
        ring_bits = get_nr_bits(session.ring_size)
        input_rst = [x.bit_extraction(idx) for idx in range(ring_bits)]
    else:
        input_rst = [x]

    for share in input_rst:
        shares = [[zero.clone(), zero.clone()] for _ in range(NR_PARTIES)]
        shares[rank][0] = share.shares[0].clone().type(tensor_type)
        shares[(rank + 1) % NR_PARTIES][1] = share.shares[1].clone().type(tensor_type)

        rst_sh: List[ReplicatedSharedTensor] = []
        for party_shares in shares:
            # Clone x to reuse its metadata, then swap in the decomposed shares.
            rst = x.clone()
            rst.shares = party_shares
            rst.ring_size = ring_size
            rst_sh.append(rst)
        share_lst.append(rst_sh)

    return share_lst
def test_session_init():
    """Test correct initialisation of the Session class."""

    def check_fresh_state(sess):
        # Attributes every newly constructed session is expected to carry.
        assert sess.crypto_store is None
        assert sess.protocol is not None
        assert sess.przs_generators == []
        assert sess.rank == -1
        assert sess.session_ptrs == []

    def check_ring(sess, expected_ring):
        # Ring-derived attributes must all agree with the requested ring size.
        assert sess.tensor_type == get_type_from_ring(expected_ring)
        assert sess.ring_size == expected_ring
        assert sess.min_value == -(expected_ring) // 2
        assert sess.max_value == (expected_ring - 1) // 2

    # Default construction.
    default_session = Session()
    assert isinstance(default_session.uuid, UUID)
    assert default_session.parties == []
    assert default_session.trusted_third_party is None
    assert isinstance(default_session.config, Config)
    check_fresh_state(default_session)
    check_ring(default_session, 2**64)

    # Construction with every argument supplied.
    custom_uuid = uuid4()
    custom_config = Config()
    custom_session = Session(
        parties=["alice", "bob"],
        ring_size=2**32,
        config=custom_config,
        ttp="TTP",
        uuid=custom_uuid,
    )
    assert custom_session.uuid == custom_uuid
    assert custom_session.parties == ["alice", "bob"]
    assert custom_session.trusted_third_party == "TTP"
    assert custom_session.config == custom_config
    check_fresh_state(custom_session)
    check_ring(custom_session, 2**32)
def test_session_default_init() -> None:
    """Test correct initialisation of the Session class with default arguments."""
    sess = Session()
    default_ring = 2**64

    # Identity and participants.
    assert sess.uuid is None
    assert sess.parties == []
    assert sess.trusted_third_party is None

    # Protocol, store and configuration.
    assert sess.crypto_store is None
    assert sess.protocol is not None
    assert isinstance(sess.config, Config)

    # Per-party state that starts out empty.
    assert sess.przs_generators == []
    assert sess.rank == -1
    assert sess.session_ptrs == []

    # Values derived from the ring size.
    assert sess.tensor_type == get_type_from_ring(default_ring)
    assert sess.ring_size == default_ring
    assert sess.min_value == -(default_ring) // 2
    assert sess.max_value == (default_ring - 1) // 2
def test_fixed_point(precision, base) -> None:
    """Check RSTensor encodes/decodes its shares exactly like FixedPointEncoder."""
    value = torch.tensor([1.25, 3.301])
    raw_shares = [value, value]
    rst = ReplicatedSharedTensor(
        shares=raw_shares,
        config=Config(encoder_precision=precision, encoder_base=base),
    )

    encoder = FixedPointEncoder(precision=precision, base=base)
    ring_dtype = get_type_from_ring(rst.ring_size)

    # Encoding side: manual encoding must match the stored shares.
    expected_encoded = [encoder.encode(item).to(ring_dtype) for item in raw_shares]
    assert (torch.cat(expected_encoded) == torch.cat(rst.shares)).all()

    # Decoding side: manual decoding must match rst.decode().
    expected_decoded = [
        encoder.decode(item.type(torch.LongTensor)) for item in expected_encoded
    ]
    assert (torch.cat(expected_decoded) == torch.cat(rst.decode())).all()
def select_shares(x: MPCTensor, y: MPCTensor, b: MPCTensor) -> MPCTensor:
    """Returns either x or y based on bit b.

    A random bit ``c`` (ring 2) is generated and also injected into the
    session ring as ``c_r``. The reconstructed ``b ^ c`` mask is combined with
    ``c_r`` to obtain a session-ring selector ``d``, and the result is
    ``x + d * (y - x)``.

    Args:
        x (MPCTensor): input tensor
        y (MPCTensor): input tensor
        b (MPCTensor): input tensor which is shares of a bit used as selector bit.

    Returns:
        z (MPCTensor): Returns x (if b==0) or y (if b==1).

    Raises:
        ValueError: If the selector bit tensor is not of ring size "2".
        ValueError: If the selector bit tensor has no shape.
    """
    ring_size = int(b.share_ptrs[0].get_ring_size().get_copy())
    shape = b.shape

    if ring_size != 2:
        raise ValueError(f"Invalid {ring_size} for selector bit,must be of ring size 2")
    if shape is None:
        raise ValueError("The selector bit tensor must have a valid shape.")

    session = x.session

    # TODO: Should be made to generate with CryptoProvider in Preprocessing stage.
    c_ptrs: List[ReplicatedSharedTensor] = [
        session_ptr.prrs_generate_random_share(shape=shape, ring_size=str(ring_size))
        for session_ptr in session.session_ptrs
    ]
    # Random bit shares in ring 2, then the same bit injected into the session ring.
    c = MPCTensor(shares=c_ptrs, session=session, shape=shape)
    c_r = ABY3.bit_injection(c, session, session.ring_size)

    tensor_type = get_type_from_ring(session.ring_size)
    mask = (b ^ c).reconstruct(decode=False).type(tensor_type)

    # Order placed carefully to prevent re-encoding, should not be changed.
    d = (mask - (c_r * mask)) + (c_r * (mask ^ 1))
    return x + (d * (y - x))
def test_session_custom_init() -> None:
    """Test correct initialisation of the Session class with custom arguments."""
    cfg = Config()
    expected_parties = ["alice", "bob"]
    ring = 2**32
    sess = Session(parties=expected_parties, ring_size=ring, config=cfg, ttp="TTP")

    # Identity and participants.
    assert sess.uuid is None
    assert sess.parties == ["alice", "bob"]
    assert sess.trusted_third_party == "TTP"

    # Protocol, store and configuration.
    assert sess.crypto_store is None
    assert sess.protocol is not None
    assert sess.config == cfg

    # Per-party state that starts out empty.
    assert sess.przs_generators == []
    assert sess.rank == -1
    assert sess.session_ptrs == []

    # Values derived from the ring size.
    assert sess.tensor_type == get_type_from_ring(ring)
    assert sess.ring_size == ring
    assert sess.min_value == -(ring) // 2
    assert sess.max_value == (ring - 1) // 2
def test_private_compare(get_clients, security) -> None:
    """Check Falcon.private_compare of secret-shared bits against a public tensor."""
    parties = get_clients(3)
    session = Session(parties=parties, protocol=Falcon(security_type=security))
    SessionManager.setup_mpc(session)

    encoder = FixedPointEncoder(
        base=session.config.encoder_base,
        precision=session.config.encoder_precision,
    )

    secret = torch.tensor([[358.85, 79.29], [67.78, 2415.50]])
    public_r = encoder.encode(torch.tensor([[357.05, 90], [145.32, 2400.54]]))

    x = MPCTensor(secret=secret, session=session)
    bit_shares = ABY3.bit_decomposition_ttp(x, session)
    # Move each bit share into the prime ring.
    prime_shares = [
        ABY3.bit_injection(bit, session, PRIME_NUMBER) for bit in bit_shares
    ]

    ring_dtype = get_type_from_ring(session.ring_size)
    result = Falcon.private_compare(prime_shares, public_r.type(ring_dtype))

    expected = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool)
    assert (result.reconstruct(decode=False) == expected).all()
def __init__(
    self,
    parties: Optional[List[Any]] = None,
    ring_size: int = 2**64,
    config: Optional[Config] = None,
    protocol: Optional[Protocol] = None,
    ttp: Optional[Any] = None,
) -> None:
    """Initializer for the Session.

    Args:
        parties (Optional[List[Any]]): Used to send/receive messages.
            Defaults to None (no parties).
        ring_size (int): Field used for the operations applied on the shares.
            Defaults to 2**64.
        config (Optional[Config]): Configuration used for information needed
            by the Fixed Point Encoder. Defaults to None, in which case a new
            ``Config()`` is created.
        protocol (Optional[Protocol]): Protocol instance. Defaults to None, in
            which case the registered "FSS" protocol is instantiated.
        ttp (Optional[Any]): Trusted third party. Defaults None.

    Raises:
        ValueError: If protocol is not registered.
        ValueError: If the protocol's security type is "malicious" and fewer
            than 3 parties are given.
    """
    # Each worker will have the rank as the index in the list
    # Only the party that is the CC (Control Center) will have access
    # to this
    self.parties: List[Any]
    self.nr_parties: int
    if parties is None:
        self.parties = []
        self.nr_parties = 0
    else:
        self.parties = parties
        self.nr_parties = len(parties)

    # Some protocols require a trusted third party
    # Ex: SPDZ
    self.trusted_third_party = ttp

    # The CryptoStore is initialized at each party when it is unserialized
    self.crypto_store: Optional[Dict[Any, Any]] = None  # TODO: this should be CryptoStore

    self.protocol: Protocol
    if protocol is None:
        self.protocol = Protocol.registered_protocols["FSS"]()
    else:
        # The registry is keyed by class name, so look up the instance's type.
        if type(protocol).__name__ not in Protocol.registered_protocols:
            raise ValueError(f"{type(protocol).__name__} not registered!")
        self.protocol = protocol

    # The check is skipped when no parties are given (an empty list is falsy).
    if (
        self.parties
        and len(self.parties) < 3
        and self.protocol.security_type == "malicious"
    ):
        raise ValueError("Malicious security cannot be provided to less than 3 parties")

    self.config = config if config else Config()

    self.przs_generators: List[Optional[torch.Generator]] = []

    # Those will be populated in the setup_mpc
    self.rank: int = -1
    self.uuid: Optional[UUID] = None
    self.session_ptrs = []
    self.rank_to_uuid: Dict[int, UUID] = {}

    # Ring size
    self.tensor_type: torch.dtype = get_type_from_ring(ring_size)
    self.ring_size = ring_size
    self.min_value = -(ring_size) // 2
    self.max_value = (ring_size - 1) // 2
    self.autograd_active = False
def __init__(
    self,
    parties: Optional[List[Any]] = None,
    ring_size: int = 2**64,
    config: Optional[Config] = None,
    protocol: Optional[str] = "FSS",
    ttp: Optional[Any] = None,
    uuid: Optional[UUID] = None,
) -> None:
    """Initializer for the Session.

    Args:
        parties (Optional[List[Any]]): Used to send/receive messages.
            Defaults to None (no parties).
        ring_size (int): Field used for the operations applied on the shares.
            Defaults to 2**64.
        config (Optional[Config]): Configuration used for information needed
            by the Fixed Point Encoder. Defaults to None, in which case a new
            ``Config()`` is created.
        protocol (Optional[str]): Name of a registered protocol. Defaults to
            "FSS"; passing None also selects "FSS".
        ttp (Optional[Any]): Trusted third party. Defaults None.
        uuid (Optional[UUID]): Universal Identifier for the session. Defaults
            to None, in which case a fresh ``uuid4()`` is generated.

    Raises:
        ValueError: If protocol is not registered.
    """
    self.uuid = uuid4() if uuid is None else uuid

    # Each worker will have the rank as the index in the list
    # Only the party that is the CC (Control Center) will have access
    # to this
    self.parties: List[Any]
    self.nr_parties: int
    if parties is None:
        self.parties = []
        self.nr_parties = 0
    else:
        self.parties = parties
        self.nr_parties = len(parties)

    # Some protocols require a trusted third party
    # Ex: SPDZ
    self.trusted_third_party = ttp

    # The CryptoStore is initialized at each party when it is unserialized
    self.crypto_store: Optional[Dict[Any, Any]] = None  # TODO: this should be CryptoStore

    # The annotation allows None; treat it as the default protocol rather
    # than failing with "None not registered!".
    if protocol is None:
        protocol = "FSS"
    if protocol not in Protocol.registered_protocols:
        raise ValueError(f"{protocol} not registered!")
    # Stores the registry entry itself; it is not instantiated here.
    self.protocol: Protocol = Protocol.registered_protocols[protocol]

    self.config = config if config else Config()

    self.przs_generators: List[List[torch.Generator]] = []

    # Those will be populated in the setup_mpc
    self.rank = -1
    self.session_ptrs: List[Session] = []

    # Ring size
    self.tensor_type: torch.dtype = get_type_from_ring(ring_size)
    self.ring_size = ring_size
    self.min_value = -(ring_size) // 2
    self.max_value = (ring_size - 1) // 2