def test_fp_encoding():
    """Test correct encoding with FixedPointEncoder.

    Each case pairs a raw input (tensor, float or int) with the integer values
    that, multiplied by the encoder scale, form the expected encoding.
    """
    encoder = FixedPointEncoder()

    cases = (
        (torch.Tensor([1, 2, 3]), [1, 2, 3]),  # tensor input
        (42.0, [42]),  # float input
        (2, [2]),  # int input
    )

    for raw_value, integer_payload in cases:
        expected = torch.LongTensor(integer_payload) * encoder.scale
        assert (encoder.encode(raw_value) == expected).all()
def test_fixed_point(precision, base) -> None:
    """Check RSTensor fixed-point encoding/decoding against a standalone encoder.

    Args:
        precision: encoder precision used for both the RSTensor config and the
            reference FixedPointEncoder.
        base: encoder base used for both.
    """
    secret = torch.tensor([1.25, 3.301])
    raw_shares = [secret, secret]

    rst_config = Config(encoder_precision=precision, encoder_base=base)
    rst = ReplicatedSharedTensor(shares=raw_shares, config=rst_config)

    reference_encoder = FixedPointEncoder(precision=precision, base=base)
    ring_dtype = get_type_from_ring(rst.ring_size)

    # Encoding must match the shares stored inside the RSTensor.
    encoded = [reference_encoder.encode(s).to(ring_dtype) for s in raw_shares]
    assert (torch.cat(encoded) == torch.cat(rst.shares)).all()

    # Decoding the reference encoding must match RSTensor.decode().
    decoded = [reference_encoder.decode(s.type(torch.LongTensor)) for s in encoded]
    assert (torch.cat(decoded) == torch.cat(rst.decode())).all()
def test_private_compare(get_clients, security) -> None:
    """Run Falcon.private_compare on a bit-decomposed secret and a public r.

    Args:
        get_clients: fixture returning a list of party clients.
        security: security type passed to the Falcon protocol.
    """
    clients = get_clients(3)
    protocol = Falcon(security_type=security)
    session = Session(parties=clients, protocol=protocol)
    SessionManager.setup_mpc(session)

    encoder = FixedPointEncoder(
        base=session.config.encoder_base,
        precision=session.config.encoder_precision,
    )

    secret = torch.tensor([[358.85, 79.29], [67.78, 2415.50]])
    public_r = encoder.encode(torch.tensor([[357.05, 90], [145.32, 2400.54]]))

    x = MPCTensor(secret=secret, session=session)
    bit_shares = ABY3.bit_decomposition_ttp(x, session)
    # Lift every bit share into the prime ring.
    prime_shares = [
        ABY3.bit_injection(bit_share, session, PRIME_NUMBER) for bit_share in bit_shares
    ]

    ring_dtype = get_type_from_ring(session.ring_size)
    result = Falcon.private_compare(prime_shares, public_r.type(ring_dtype))

    # The expected pattern coincides with elementwise secret > r for this data.
    expected = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool)
    assert (result.reconstruct(decode=False) == expected).all()
class ShareTensor(metaclass=SyMPCTensor):
    """Single Share representation.

    Arguments:
        data (Optional[Any]): the share a party holds
        config (Optional[Config]): the configuration holding the encoder base and
            precision
        session_uuid (Optional[UUID]): the session this share belongs to
        ring_size (int): field used for the operations applied on the shares

    Attributes:
        Syft Serializable Attributes

        id (UID): the id to store the session
        tags (Optional[List[str]): an optional list of strings that are tags used at search
        description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session_uuid (Optional[UUID]): keep track from which session the share belongs to
        config (Config): holds the encoder precision and base
        fp_encoder (FixedPointEncoder): encoder built from ``config``
        ring_size (int): field used for the operations applied on the shares
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session_uuid",
        "config",
        "fp_encoder",
        "ring_size",
    }

    # Used by the SyMPCTensor metaclass
    METHODS_FORWARD: Set[str] = {
        "numel",
        "squeeze",
        "unsqueeze",
        "t",
        "view",
        "expand",
        "sum",
        "clone",
        "flatten",
        "reshape",
        "repeat",
        "narrow",
        "dim",
        "transpose",
        "roll",
    }
    PROPERTIES_FORWARD: Set[str] = {"T", "shape"}

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        config: Optional[Config] = None,
        session_uuid: Optional[UUID] = None,
        ring_size: int = 2**64,
    ) -> None:
        """Initialize ShareTensor.

        Args:
            data (Optional[Any]): The share a party holds. Defaults to None
            config (Optional[Config]): The configuration where we keep the encoder
                precision and base. Defaults to None, which means a new
                Config(encoder_base=2, encoder_precision=16) for this instance.
            session_uuid (Optional[UUID]): Used to keep track of a share that is
                associated with a remote session
            ring_size (int): field used for the operations applied on the shares
                Defaults to 2**64
        """
        if config is None:
            # Built per call: a Config(...) default argument would be evaluated
            # once at definition time and shared by every ShareTensor.
            config = Config(encoder_base=2, encoder_precision=16)

        self.session_uuid = session_uuid
        self.ring_size = ring_size
        self.config = config
        self.fp_encoder = FixedPointEncoder(
            base=config.encoder_base, precision=config.encoder_precision
        )

        self.tensor: Optional[torch.Tensor] = None
        if data is not None:
            tensor_type = get_type_from_ring(ring_size)
            self.tensor = self._encode(data).to(tensor_type)

    def _encode(self, data):
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPrecisionEncoder.

        Returns:
            torch.Tensor: Decoded value
        """
        return self._decode()

    def _decode(self):
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(
        x: "ShareTensor",
        y: Union[int, float, torch.Tensor, "ShareTensor"],
        op_str: str,
    ) -> "ShareTensor":
        """Check the type of "y" and convert it to share if necessary.

        Args:
            x (ShareTensor): Typically "self".
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Tensor to check.
            op_str (str): String operator (not used by the checks themselves).

        Returns:
            ShareTensor: the converted y value.

        Raises:
            ValueError: if both values are shares and they have different uuids
        """
        if not isinstance(y, ShareTensor):
            # Encode the public value with the ring/config of "x", taken from its
            # session when "x" is attached to one.
            if x.session_uuid is not None:
                session = sympc.session.get_session(str(x.session_uuid))
                ring_size = session.ring_size
                config = session.config
            else:
                ring_size = x.ring_size
                config = x.config

            y = ShareTensor(data=y, ring_size=ring_size, config=config)
        elif y.session_uuid and x.session_uuid and y.session_uuid != x.session_uuid:
            raise ValueError(
                f"Session UUIDs did not match {x.session_uuid} {y.session_uuid}"
            )

        return y

    def apply_function(
        self, y: Union["ShareTensor", torch.Tensor, int, float], op_str: str
    ) -> "ShareTensor":
        """Apply a given operation.

        Args:
            y (Union["ShareTensor", torch.Tensor, int, float]): tensor to apply the operator.
            op_str (str): Operator name looked up in the ``operator`` module.

        Returns:
            ShareTensor: Result of the operation.
        """
        op = getattr(operator, op_str)
        y_is_share = isinstance(y, ShareTensor)
        value = op(self.tensor, y.tensor if y_is_share else y)

        # Only a ShareTensor operand carries a session UUID; public values
        # (int/float/torch.Tensor) reach this method unconverted from xor().
        session_uuid = self.session_uuid or (y.session_uuid if y_is_share else None)
        if session_uuid is not None:
            session = sympc.session.get_session(str(session_uuid))
            ring_size = session.ring_size
            config = session.config
        else:
            # Use the values from "self"
            ring_size = self.ring_size
            config = self.config

        res = ShareTensor(ring_size=ring_size, session_uuid=session_uuid, config=config)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self + y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        return self.apply_function(y_share, "add")

    def sub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self - y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        return self.apply_function(y_share, "sub")

    def rsub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): y - self

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        return y_share.apply_function(self, "sub")

    def mul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self * y

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session_uuid is None and y.session_uuid is None:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor //= self.fp_encoder.scale

        return res

    def xor(self, y: Union[int, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "xor" operation between "self" and "y".

        Args:
            y (Union[int, torch.Tensor, "ShareTensor"]): self xor y

        Returns:
            ShareTensor: Result of the operation.
        """
        # Public values are XOR'd as-is (no fixed-point encoding).
        return self.apply_function(y, "xor")

    def matmul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self @ y.

        Returns:
            ShareTensor: Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session_uuid is None and y.session_uuid is None:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y (torch.Tensor): y @ self

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "div" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Denominator.

        Returns:
            ShareTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer or LongTensor.
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        # Keep the ring size of "self"; otherwise the result would silently fall
        # back to the default 2**64 ring.
        res = ShareTensor(
            session_uuid=self.session_uuid,
            config=self.config,
            ring_size=self.ring_size,
        )
        res.tensor = self.tensor // y
        return res

    def __gt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Greater than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Elementwise result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        return self.tensor > y_share.tensor

    def __lt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Lower than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Elementwise result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "lt")
        return self.tensor < y_share.tensor

    def __str__(self) -> str:
        """Representation.

        Returns:
            str: Return the string representation of ShareTensor.
        """
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| Session UUID: {self.session_uuid}"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self) -> str:
        """Representation.

        Returns:
            String representation.
        """
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Equal operator.

        Check if "self" is equal with another object given a set of
        attributes to compare.

        Args:
            other (Any): Tensor to compare.

        Returns:
            bool: True if equal False if not.
        """
        if not (self.tensor == other.tensor).all():
            return False

        if not self.config == other.config:
            return False

        if (
            self.session_uuid
            and other.session_uuid
            and self.session_uuid != other.session_uuid
        ):
            # If both shares have a session_uuid we consider them not equal
            # else they are
            return False

        return True

    @staticmethod
    def hook_property(property_name: str) -> Any:
        """Hook a framework property (only getter).

        Ex:
         * if we call "shape" we want to call it on the underlying tensor
        and return the result
         * if we call "T" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            property_name (str): property to hook

        Returns:
            A hooked property
        """

        def property_new_share_tensor_getter(_self: "ShareTensor") -> Any:
            tensor = getattr(_self.tensor, property_name)
            # Propagate ring_size so the wrapped result stays in the same ring.
            res = ShareTensor(
                session_uuid=_self.session_uuid,
                config=_self.config,
                ring_size=_self.ring_size,
            )
            res.tensor = tensor
            return res

        def property_getter(_self: "ShareTensor") -> Any:
            return getattr(_self.tensor, property_name)

        if property_name in PROPERTIES_NEW_SHARE_TENSOR:
            return property(property_new_share_tensor_getter, None)
        return property(property_getter, None)

    @staticmethod
    def hook_method(method_name: str) -> Callable[..., Any]:
        """Hook a framework method such that we know how to treat it given that we call it.

        Ex:
         * if we call "numel" we want to call it on the underlying tensor
        and return the result
         * if we call "unsqueeze" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            method_name (str): method to hook

        Returns:
            A hooked method
        """

        def method_new_share_tensor(
            _self: "ShareTensor", *args: List[Any], **kwargs: Dict[Any, Any]
        ) -> Any:
            method = getattr(_self.tensor, method_name)
            tensor = method(*args, **kwargs)
            # Propagate ring_size so the wrapped result stays in the same ring.
            res = ShareTensor(
                session_uuid=_self.session_uuid,
                config=_self.config,
                ring_size=_self.ring_size,
            )
            res.tensor = tensor
            return res

        def method(
            _self: "ShareTensor", *args: List[Any], **kwargs: Dict[Any, Any]
        ) -> Any:
            method = getattr(_self.tensor, method_name)
            return method(*args, **kwargs)

        if method_name in METHODS_NEW_SHARE_TENSOR:
            return method_new_share_tensor
        return method

    @staticmethod
    def reconstruct(
        share_ptrs: List["ShareTensor"],
        get_shares=False,
        security_type: str = "semi-honest",
    ) -> torch.Tensor:
        """Reconstruct original value from shares.

        Args:
            share_ptrs (List[ShareTensor]): List of sharetensors.
            get_shares (boolean): retrieve shares or reconstructed value.
            security_type (str): Type of security by protocol (not used here).

        Returns:
            plaintext/shares (torch.Tensor/List[torch.Tensors]): Plaintext or list of shares.
        """

        def _request_and_get(share_ptr: ShareTensor) -> ShareTensor:
            """Function used to request and get a share - Duet Setup.

            Args:
                share_ptr (ShareTensor): a ShareTensor

            Returns:
                ShareTensor. The ShareTensor in local.
            """
            if not islocal(share_ptr):
                share_ptr.request(block=True)
            return share_ptr.get_copy()

        request_wrap = parallel_execution(_request_and_get)

        args = [[share] for share in share_ptrs]
        local_shares = request_wrap(args)

        shares = [share.tensor for share in local_shares]

        if get_shares:
            return shares

        # Additive reconstruction: the plaintext is the sum of all shares.
        return sum(shares)

    @staticmethod
    def distribute_shares(shares: List["ShareTensor"], session: Session):
        """Distribute a list of shares.

        Args:
            shares (List[ShareTensor): list of shares to distribute.
            session (Session): Session for which those shares were generated

        Returns:
            List of ShareTensorPointers.
        """
        rank_to_uuid = session.rank_to_uuid
        parties = session.parties

        share_ptrs = []
        for rank, share in enumerate(shares):
            # Tag the share with the receiving party's session UUID before sending.
            share.session_uuid = rank_to_uuid[rank]
            party = parties[rank]
            share_ptrs.append(share.send(party))

        return share_ptrs

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
    __xor__ = xor
class ShareTensor:
    """This class represents 1 share that a party holds when doing secret sharing.

    Attributes:
        Syft Serializable Attributes

        id (UID): the id to store the session
        tags (Optional[List[str]): an optional list of strings that are tags used at search
        description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session (Session): keep track from which session this share belongs to
        fp_encoder (FixedPointEncoder): the encoder used to convert a share from/to fixed point
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session",
        "fp_encoder",
    }

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        session: Optional[Session] = None,
        encoder_base: int = 2,
        encoder_precision: int = 16,
        ring_size: int = 2**64,
    ) -> None:
        """Initialize ShareTensor.

        Args:
            data (Optional[Any]): The share a party holds. Defaults to None
            session (Optional[Any]): The session from which this shares belongs to.
                Defaults to None.
            encoder_base (int): The base for the encoder. Defaults to 2. Ignored
                when a session is given.
            encoder_precision (int): the precision for the encoder. Defaults to 16.
                Ignored when a session is given.
            ring_size (int): field used for the operations applied on the shares
                Defaults to 2**64. Ignored when a session is given.
        """
        if session is None:
            self.session = Session(ring_size=ring_size)
            self.session.config.encoder_precision = encoder_precision
            self.session.config.encoder_base = encoder_base
        else:
            self.session = session

        # In both branches the session config now holds the encoder settings,
        # so a single construction covers them.
        config = self.session.config
        self.fp_encoder = FixedPointEncoder(
            base=config.encoder_base, precision=config.encoder_precision
        )

        self.tensor: Optional[torch.Tensor] = None
        if data is not None:
            tensor_type = self.session.tensor_type
            self.tensor = self._encode(data).type(tensor_type)

    def _encode(self, data):
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPrecisionEncoder.

        Returns:
            torch.Tensor: Decoded value
        """
        return self._decode()

    def _decode(self):
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(
        x: "ShareTensor",
        y: Union[int, float, torch.Tensor, "ShareTensor"],
        op_str: str,
    ) -> "ShareTensor":
        """Check the type of "y" and convert it to share if necessary.

        Args:
            x (ShareTensor): Typically "self".
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Tensor to check.
            op_str (str): String operator (not used by the checks themselves).

        Returns:
            ShareTensor: the converted y value.
        """
        if not isinstance(y, ShareTensor):
            y = ShareTensor(data=y, session=x.session)

        return y

    def apply_function(
        self, y: Union["ShareTensor", torch.Tensor, int, float], op_str: str
    ) -> "ShareTensor":
        """Apply a given operation.

        Args:
            y (Union["ShareTensor", torch.Tensor, int, float]): tensor to apply the operator.
            op_str (str): Operator name looked up in the ``operator`` module.

        Returns:
            ShareTensor: Result of the operation.
        """
        op = getattr(operator, op_str)

        if isinstance(y, ShareTensor):
            value = op(self.tensor, y.tensor)
        else:
            value = op(self.tensor, y)

        res = ShareTensor(session=self.session)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self + y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        return self.apply_function(y_share, "add")

    def sub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self - y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        return self.apply_function(y_share, "sub")

    def rsub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): y - self

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        return y_share.apply_function(self, "sub")

    def mul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self * y

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def xor(self, y: Union[int, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "xor" operation between "self" and "y".

        Args:
            y (Union[int, torch.Tensor, "ShareTensor"]): self xor y

        Returns:
            ShareTensor: Result of the operation.
        """
        res = ShareTensor(session=self.session)

        # Public values are XOR'd as-is (no fixed-point encoding).
        if isinstance(y, ShareTensor):
            res.tensor = self.tensor ^ y.tensor
        else:
            res.tensor = self.tensor ^ y

        return res

    def matmul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self @ y.

        Returns:
            ShareTensor: Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y (torch.Tensor): y @ self

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "div" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Denominator.

        Returns:
            ShareTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer or LongTensor.
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        res = ShareTensor(session=self.session)
        res.tensor = self.tensor // y
        return res

    def __getattr__(self, attr_name: str) -> Any:
        """Access to tensor attributes.

        Only called when normal attribute lookup fails; the lookup is then
        forwarded to the underlying tensor.

        Args:
            attr_name (str): Name of the attribute.

        Returns:
            Any: Attribute.

        Raises:
            AttributeError: If the "tensor" slot itself is not set, or the
                tensor has no such attribute.
        """
        if attr_name == "tensor":
            # The "tensor" slot is unset (e.g. an instance created through
            # __new__ by copy/pickle). Reading self.tensor below would re-enter
            # __getattr__("tensor") and recurse until RecursionError.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute 'tensor'"
            )

        return getattr(self.tensor, attr_name)

    def __gt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Greater than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Elementwise result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        return self.tensor > y_share.tensor

    def __lt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Lower than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Elementwise result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "lt")
        return self.tensor < y_share.tensor

    def __str__(self) -> str:
        """Representation.

        Returns:
            str: Return the string representation of ShareTensor.
        """
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self) -> str:
        """Representation.

        Returns:
            String representation.
        """
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Equal operator.

        Check if "self" is equal with another object given a set of
        attributes to compare.

        Args:
            other (Any): Tensor to compare.

        Returns:
            bool: True if equal False if not.
        """
        if not (self.tensor == other.tensor).all():
            return False

        if not (self.session == other.session):
            return False

        return True

    # Forward to tensor methods

    @property
    def shape(self) -> Any:
        """Shape of the tensor.

        Returns:
            Any: Shape of the tensor.
        """
        return self.tensor.shape

    def numel(self, *args: List[Any], **kwargs: Dict[Any, Any]) -> Any:
        """Total number of elements.

        Args:
            *args: Arguments passed to tensor.numel.
            **kwargs: Keyword arguments passed to tensor.numel.

        Returns:
            Any: Total number of elements of the tensor.
        """
        return self.tensor.numel(*args, **kwargs)

    @property
    def T(self) -> Any:
        """Transpose.

        Returns:
            Any: ShareTensor transposed.
        """
        res = ShareTensor(session=self.session)
        res.tensor = self.tensor.T
        return res

    def unsqueeze(self, *args: List[Any], **kwargs: Dict[Any, Any]) -> Any:
        """Tensor with a dimension of size one inserted at the specified position.

        Args:
            *args: Arguments to tensor.unsqueeze
            **kwargs: Keyword arguments passed to tensor.unsqueeze

        Returns:
            Any: ShareTensor unsqueezed.

        References:
            https://pytorch.org/docs/stable/generated/torch.unsqueeze.html
        """
        tensor = self.tensor.unsqueeze(*args, **kwargs)
        res = ShareTensor(session=self.session)
        res.tensor = tensor
        return res

    def view(self, *args: List[Any], **kwargs: Dict[Any, Any]) -> Any:
        """Tensor with the same data but new dimensions/view.

        Args:
            *args: Arguments to tensor.view.
            **kwargs: Keyword arguments passed to tensor.view.

        Returns:
            Any: ShareTensor with new view.

        References:
            https://pytorch.org/docs/stable/tensors.html#torch.Tensor.view
        """
        tensor = self.tensor.view(*args, **kwargs)
        res = ShareTensor(session=self.session)
        res.tensor = tensor
        return res

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
    __xor__ = xor
class ReplicatedSharedTensor(metaclass=SyMPCTensor): """RSTensor is used when a party holds more than a single share,required by various protocols. Arguments: shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list from which RSTensor is created. Attributes: shares: The shares held by the party """ __slots__ = { # Populated in Syft "id", "tags", "description", "shares", "session_uuid", "config", "fp_encoder", "ring_size", } # Used by the SyMPCTensor metaclass METHODS_FORWARD = { "numel", "t", "unsqueeze", "view", "sum", "clone", "repeat" } PROPERTIES_FORWARD = {"T", "shape"} def __init__( self, shares: Optional[List[Union[float, int, torch.Tensor]]] = None, config: Config = Config(encoder_base=2, encoder_precision=16), session_uuid: Optional[UUID] = None, ring_size: int = 2**64, ): """Initialize ReplicatedSharedTensor. Args: shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list from which RSTensor is created. config (Config): The configuration where we keep the encoder precision and base. session_uuid (Optional[UUID]): Used to keep track of a share that is associated with a remote session ring_size (int): field used for the operations applied on the shares Defaults to 2**64 """ self.session_uuid = session_uuid self.ring_size = ring_size if ring_size in {2, PRIME_NUMBER}: self.config = Config(encoder_base=1, encoder_precision=0) else: self.config = config self.fp_encoder = FixedPointEncoder( base=self.config.encoder_base, precision=self.config.encoder_precision) tensor_type = get_type_from_ring(ring_size) self.shares = [] if shares is not None: self.shares = [ self._encode(share).to(tensor_type) for share in shares ] def _encode(self, data: torch.Tensor) -> torch.Tensor: """Encode via FixedPointEncoder. Args: data (torch.Tensor): Tensor to be encoded Returns: encoded_data (torch.Tensor): Decoded values """ return self.fp_encoder.encode(data) def decode(self) -> List[torch.Tensor]: """Decode via FixedPointEncoder. 
Returns: List[torch.Tensor]: Decoded values """ return self._decode() def _decode(self) -> List[torch.Tensor]: """Decodes shares list of RSTensor via FixedPointEncoder. Returns: List[torch.Tensor]: Decoded values """ shares = [] shares = [ self.fp_encoder.decode(share.type(torch.LongTensor)) for share in self.shares ] return shares def get_shares(self) -> List[torch.Tensor]: """Get shares. Returns: List[torch.Tensor]: List of shares. """ return self.shares def get_ring_size(self) -> str: """Ring size of tensor. Returns: ring_size (str): Returns ring_size of tensor in string. It is typecasted to string as we cannot serialize 2**64 """ return str(self.ring_size) def get_config(self) -> Dict: """Config of tensor. Returns: config (Dict): returns config of the tensor as dict. """ return dataclasses.asdict(self.config) @staticmethod def addmodprime(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Computes addition(x+y) modulo PRIME_NUMBER constant. Args: x (torch.Tensor): input tensor y (torch.tensor): input tensor Returns: value (torch.Tensor): Result of the operation. Raises: ValueError : If either of the tensors datatype is not torch.uint8 """ if x.dtype != torch.uint8 or y.dtype != torch.uint8: raise ValueError( f"Both tensors x:{x.dtype} y:{y.dtype} should be of torch.uint8 dtype" ) return (x + y) % PRIME_NUMBER @staticmethod def submodprime(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Computes subtraction(x-y) modulo PRIME_NUMBER constant. Args: x (torch.Tensor): input tensor y (torch.tensor): input tensor Returns: value (torch.Tensor): Result of the operation. Raises: ValueError : If either of the tensors datatype is not torch.uint8 """ if x.dtype != torch.uint8 or y.dtype != torch.uint8: raise ValueError( f"Both tensors x:{x.dtype} y:{y.dtype} should be of torch.uint8 dtype" ) # Typecasting is done, as underflow returns a positive number,as it is unsigned. 
x = x.to(torch.int8) y = y.to(torch.int8) result = (x - y) % PRIME_NUMBER return result.to(torch.uint8) @staticmethod def mulmodprime(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Computes multiplication(x*y) modulo PRIME_NUMBER constant. Args: x (torch.Tensor): input tensor y (torch.tensor): input tensor Returns: value (torch.Tensor): Result of the operation. Raises: ValueError : If either of the tensors datatype is not torch.uint8 """ if x.dtype != torch.uint8 or y.dtype != torch.uint8: raise ValueError( f"Both tensors x:{x.dtype} y:{y.dtype} should be of torch.uint8 dtype" ) # We typecast as multiplication result in 2n bits ,which causes overflow. x = x.to(torch.int16) y = y.to(torch.int16) result = (x * y) % PRIME_NUMBER return result.to(torch.uint8) @staticmethod def get_op(ring_size: int, op_str: str) -> Callable[..., Any]: """Returns method attribute based on ring_size and op_str. Args: ring_size (int): Ring size op_str (str): Operation string. Returns: op (Callable[...,Any]): The operation method for the op_str. Raises: ValueError : If invalid ring size is given as input. """ op = None if ring_size == 2: op = getattr(operator, BINARY_MAP[op_str]) elif ring_size == PRIME_NUMBER: op = getattr(ReplicatedSharedTensor, op_str + "modprime") elif ring_size in RING_SIZE_TO_TYPE.keys(): op = getattr(operator, op_str) else: raise ValueError(f"Invalid ring size: {ring_size}") return op @staticmethod def sanity_checks( x: "ReplicatedSharedTensor", y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"], ) -> "ReplicatedSharedTensor": """Check the type of "y" and convert it to share if necessary. Args: x (ReplicatedSharedTensor): Typically "self". y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): Tensor to check. Returns: ReplicatedSharedTensor: the converted y value. Raises: ValueError: if both values are shares and they have different uuids ValueError: if both values have different number of shares. 
ValueError: if both RSTensor have different ring_sizes """ if not isinstance(y, ReplicatedSharedTensor): # As prime ring size is unsigned,we convert negative values. y = y % PRIME_NUMBER if x.ring_size == PRIME_NUMBER else y y = ReplicatedSharedTensor( session_uuid=x.session_uuid, shares=[y], ring_size=x.ring_size, config=x.config, ) elif y.session_uuid and x.session_uuid and y.session_uuid != x.session_uuid: raise ValueError( f"Session UUIDs did not match {x.session_uuid} {y.session_uuid}" ) elif len(x.shares) != len(y.shares): raise ValueError( f"Both RSTensors should have equal number of shares {len(x.shares)} {len(y.shares)}" ) elif x.ring_size != y.ring_size: raise ValueError( f"Both RSTensors should have same ring_size {x.ring_size} {y.ring_size}" ) session_uuid = x.session_uuid if session_uuid is not None: session = sympc.session.get_session(str(x.session_uuid)) else: session = Session(config=x.config, ring_size=x.ring_size) session.nr_parties = 1 return y, session def __apply_public_op(self, y: Union[torch.Tensor, float, int], op_str: str) -> "ReplicatedSharedTensor": """Apply an operation on "self" which is a RSTensor and a public value. Args: y (Union[torch.Tensor, float, int]): Tensor to apply the operation. op_str (str): The operation. Returns: ReplicatedSharedTensor: The operation "op_str" applied on "self" and "y" Raises: ValueError: If "op_str" is not supported. 
""" y, session = ReplicatedSharedTensor.sanity_checks(self, y) op = ReplicatedSharedTensor.get_op(self.ring_size, op_str) shares = copy.deepcopy(self.shares) if op_str in {"add", "sub"}: if session.rank != 1: idx = (session.nr_parties - session.rank) % session.nr_parties shares[idx] = op(shares[idx], y.shares[0]) else: raise ValueError(f"{op_str} not supported") result = ReplicatedSharedTensor( ring_size=self.ring_size, session_uuid=self.session_uuid, config=self.config, ) result.shares = shares return result def __apply_private_op(self, y: "ReplicatedSharedTensor", op_str: str) -> "ReplicatedSharedTensor": """Apply an operation on 2 RSTensors (secret shared values). Args: y (RelicatedSharedTensor): Tensor to apply the operation op_str (str): The operation Returns: ReplicatedSharedTensor: The operation "op_str" applied on "self" and "y" Raises: ValueError: If "op_str" not supported. """ y, session = ReplicatedSharedTensor.sanity_checks(self, y) op = ReplicatedSharedTensor.get_op(self.ring_size, op_str) shares = [] if op_str in {"add", "sub"}: for x_share, y_share in zip(self.shares, y.shares): shares.append(op(x_share, y_share)) else: raise ValueError(f"{op_str} not supported") result = ReplicatedSharedTensor( ring_size=self.ring_size, session_uuid=self.session_uuid, config=self.config, ) result.shares = shares return result def __apply_op( self, y: Union["ReplicatedSharedTensor", torch.Tensor, float, int], op_str: str, ) -> "ReplicatedSharedTensor": """Apply a given operation ". This function checks if "y" is private or public value. Args: y (Union[ReplicatedSharedTensor,torch.Tensor, float, int]): tensor to apply the operation. op_str (str): the operation. 
Returns: ReplicatedSharedTensor: the operation "op_str" applied on "self" and "y" """ is_private = isinstance(y, ReplicatedSharedTensor) if is_private: result = self.__apply_private_op(y, op_str) else: result = self.__apply_public_op(y, op_str) return result def add( self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"] ) -> "ReplicatedSharedTensor": """Apply the "add" operation between "self" and "y". Args: y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self + y Returns: ReplicatedSharedTensor: Result of the operation. """ return self.__apply_op(y, "add") def sub( self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"] ) -> "ReplicatedSharedTensor": """Apply the "sub" operation between "self" and "y". Args: y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self - y Returns: ReplicatedSharedTensor: Result of the operation. """ return self.__apply_op(y, "sub") def rsub( self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"] ) -> "ReplicatedSharedTensor": """Apply the "sub" operation between "y" and "self". Args: y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): y -self Returns: ReplicatedSharedTensor: Result of the operation. """ return self.__apply_op(y, "sub") def mul( self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"] ) -> "ReplicatedSharedTensor": """Apply the "mul" operation between "self" and "y". Args: y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self*y Returns: ReplicatedSharedTensor: Result of the operation. Raises: ValueError: Raised when private mul is performed parties!=3. 
""" y_tensor, session = self.sanity_checks(self, y) is_private = isinstance(y, ReplicatedSharedTensor) op_str = "mul" op = ReplicatedSharedTensor.get_op(self.ring_size, op_str) if is_private: if session.nr_parties == 3: from sympc.protocol import Falcon result = [ Falcon.multiplication_protocol(self, y_tensor, op_str) ] else: raise ValueError( "Private mult between ReplicatedSharedTensors is allowed only for 3 parties" ) else: result = [op(share, y_tensor.shares[0]) for share in self.shares] tensor = ReplicatedSharedTensor(ring_size=self.ring_size, session_uuid=self.session_uuid, config=self.config) tensor.shares = result return tensor def matmul( self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"] ) -> "ReplicatedSharedTensor": """Apply the "matmul" operation between "self" and "y". Args: y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self@y Returns: ReplicatedSharedTensor: Result of the operation. Raises: ValueError: Raised when private matmul is performed parties!=3. """ y_tensor, session = self.sanity_checks(self, y) is_private = isinstance(y, ReplicatedSharedTensor) op_str = "matmul" if is_private: if session.nr_parties == 3: from sympc.protocol import Falcon result = [ Falcon.multiplication_protocol(self, y_tensor, op_str) ] else: raise ValueError( "Private matmul between ReplicatedSharedTensors is allowed only for 3 parties" ) else: result = [ operator.matmul(share, y_tensor.shares[0]) for share in self.shares ] tensor = ReplicatedSharedTensor(ring_size=self.ring_size, session_uuid=self.session_uuid, config=self.config) tensor.shares = result return tensor def truediv(self, y: Union[int, torch.Tensor]) -> "ReplicatedSharedTensor": """Apply the "div" operation between "self" and "y". Args: y (Union[int , torch.Tensor]): Denominator. Returns: ReplicatedSharedTensor: Result of the operation. Raises: ValueError: If y is not an integer or LongTensor. 
""" if not isinstance(y, (int, torch.LongTensor)): raise ValueError( "Div works (for the moment) only with integers and LongTensor!" ) res = ReplicatedSharedTensor(session_uuid=self.session_uuid, config=self.config, ring_size=self.ring_size) res.shares = [share // y for share in self.shares] return res def rshift(self, y: int) -> "ReplicatedSharedTensor": """Apply the "rshift" operation to "self". Args: y (int): shift value Returns: ReplicatedSharedTensor: Result of the operation. Raises: ValueError: If y is not an integer. ValueError : If invalid shift value is provided. """ if not isinstance(y, int): raise ValueError("Right Shift works only with integers!") ring_bits = get_nr_bits(self.ring_size) if y < 0 or y > ring_bits - 1: raise ValueError( f"Invalid value for right shift: {y}, must be in range:[0,{ring_bits-1}]" ) res = ReplicatedSharedTensor(session_uuid=self.session_uuid, config=self.config, ring_size=self.ring_size) res.shares = [share >> y for share in self.shares] return res def bit_extraction(self, pos: int = 0) -> "ReplicatedSharedTensor": """Extracts the bit at the specified position. Args: pos (int): position to extract bit. Returns: ReplicatedSharedTensor : extracted bits at specific position. Raises: ValueError: If invalid position is provided. """ ring_bits = get_nr_bits(self.ring_size) if pos < 0 or pos > ring_bits - 1: raise ValueError( f"Invalid position for bit_extraction: {pos}, must be in range:[0,{ring_bits-1}]" ) shares = [] # logical shift bit_mask = torch.ones(self.shares[0].shape, dtype=self.shares[0].dtype) << pos shares = [share & bit_mask for share in self.shares] rst = ReplicatedSharedTensor( shares=shares, session_uuid=self.session_uuid, config=Config(encoder_base=1, encoder_precision=0), ring_size=2, ) return rst def rmatmul(self, y): """Apply the "rmatmul" operation between "y" and "self". 
Args: y: self@y Raises: NotImplementedError: Raised when implementation not present """ raise NotImplementedError def xor( self, y: Union[int, torch.Tensor, "ReplicatedSharedTensor"] ) -> "ReplicatedSharedTensor": """Apply the "xor" operation between "self" and "y". Args: y: public bit Returns: ReplicatedSharedTensor: Result of the operation. Raises: ValueError : If ring size is invalid. """ if self.ring_size == 2: return self + y elif self.ring_size in RING_SIZE_TO_TYPE: return self + y - (self * y * 2) else: raise ValueError( f"The ring_size {self.ring_size} is not supported.") def lt(self, y): """Lower than operator. Args: y: self<y Raises: NotImplementedError: Raised when implementation not present """ raise NotImplementedError def gt(self, y): """Greater than operator. Args: y: self>y Raises: NotImplementedError: Raised when implementation not present """ raise NotImplementedError def eq(self, y: Any) -> bool: """Equal operator. Check if "self" is equal with another object given a set of attributes to compare. Args: y (Any): Object to compare Returns: bool: True if equal False if not. """ if not (torch.cat(self.shares) == torch.cat(y.shares)).all(): return False if self.config != y.config: return False if self.session_uuid and y.session_uuid and self.session_uuid != y.session_uuid: return False if self.ring_size != y.ring_size: return False return True def __getitem__(self, key: int) -> torch.Tensor: """Allows to subset shares. Args: key (int): The share to be retrieved. Returns: share (torch.Tensor): Returned share. """ return self.shares[key] def __setitem__(self, key: int, newvalue: torch.Tensor) -> None: """Allows to set share value to new value. Args: key (int): The share to be retrieved. newvalue (torch.Tensor): New value of share. """ self.shares[key] = newvalue def ne(self, y): """Not Equal operator. 
Args: y: self!=y Raises: NotImplementedError: Raised when implementation not present """ raise NotImplementedError @staticmethod def shares_sum(shares: List[torch.Tensor], ring_size: int) -> torch.Tensor: """Returns sum of tensors based on ring_size. Args: shares (List[torch.Tensor]) : List of tensors. ring_size (int): Ring size of share associated with the tensors. Returns: value (torch.Tensor): sum of the tensors. """ if ring_size == 2: return reduce(lambda x, y: x ^ y, shares) elif ring_size == PRIME_NUMBER: return reduce(ReplicatedSharedTensor.addmodprime, shares) else: return sum(shares) @staticmethod def _request_and_get( share_ptr: "ReplicatedSharedTensor", ) -> "ReplicatedSharedTensor": """Function used to request and get a share - Duet Setup. Args: share_ptr (ReplicatedSharedTensor): input ReplicatedSharedTensor Returns: ReplicatedSharedTensor : The ReplicatedSharedTensor in local. """ if not ispointer(share_ptr): return share_ptr elif not islocal(share_ptr): share_ptr.request(block=True) res = share_ptr.get_copy() return res @staticmethod def __reconstruct_semi_honest( share_ptrs: List["ReplicatedSharedTensor"], get_shares: bool = False, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Reconstruct value from shares. Args: share_ptrs (List[ReplicatedSharedTensor]): List of RSTensor pointers. get_shares (bool): Retrieve only shares. Returns: reconstructed_value (torch.Tensor): Reconstructed value. """ request = ReplicatedSharedTensor._request_and_get request_wrap = parallel_execution(request) args = [[share] for share in share_ptrs[:2]] local_shares = request_wrap(args) shares = [local_shares[0].shares[0]] shares.extend(local_shares[1].shares) if get_shares: return shares ring_size = local_shares[0].ring_size return ReplicatedSharedTensor.shares_sum(shares, ring_size) @staticmethod def __reconstruct_malicious( share_ptrs: List["ReplicatedSharedTensor"], get_shares: bool = False, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Reconstruct value from shares. 
Args: share_ptrs (List[ReplicatedSharedTensor]): List of RSTensor pointers. get_shares (bool): Retrieve only shares. Returns: reconstructed_value (torch.Tensor): Reconstructed value. Raises: ValueError: When parties share values are not equal. """ nparties = len(share_ptrs) # Get shares from all parties request = ReplicatedSharedTensor._request_and_get request_wrap = parallel_execution(request) args = [[share] for share in share_ptrs] local_shares = request_wrap(args) ring_size = local_shares[0].ring_size shares_sum = ReplicatedSharedTensor.shares_sum all_shares = [rst.shares for rst in local_shares] # reconstruct shares from all parties and verify value = None for party_rank in range(nparties): tensor = shares_sum( [all_shares[party_rank][0]] + all_shares[(party_rank + 1) % (nparties)], ring_size, ) if value is None: value = tensor elif (tensor != value).any(): raise ValueError( "Reconstruction values from all parties are not equal.") if get_shares: return all_shares return value @staticmethod def reconstruct( share_ptrs: List["ReplicatedSharedTensor"], get_shares: bool = False, security_type: str = "semi-honest", ) -> Union[torch.Tensor, List[torch.Tensor]]: """Reconstruct value from shares. Args: share_ptrs (List[ReplicatedSharedTensor]): List of RSTensor pointers. security_type (str): Type of security followed by protocol. get_shares (bool): Retrieve only shares. Returns: reconstructed_value (torch.Tensor): Reconstructed value. Raises: ValueError: Invalid security type. ValueError : SharePointers not provided. 
""" if not len(share_ptrs): raise ValueError( "Share pointers must be provided for reconstruction.") if security_type == "malicious": return ReplicatedSharedTensor.__reconstruct_malicious( share_ptrs, get_shares) elif security_type == "semi-honest": return ReplicatedSharedTensor.__reconstruct_semi_honest( share_ptrs, get_shares) raise ValueError("Invalid security Type") @staticmethod def distribute_shares_to_party( shares: List[Union[ShareTensor, torch.Tensor]], party_rank: int, session: Session, ring_size: int, config: Config, ) -> "ReplicatedSharedTensor": """Distributes shares to party. Args: shares (List[Union[ShareTensor,torch.Tensor]]): Shares to be distributed. party_rank (int): Rank of party. session (Session): Current session ring_size (int): Ring size of tensor to distribute config (Config): The configuration(base,precision) of the tensor. Returns: tensor (ReplicatedSharedTensor): Tensor with shares Raises: TypeError: Invalid share class. """ party = session.parties[party_rank] nshares = session.nr_parties - 1 party_shares = [] for share_index in range(party_rank, party_rank + nshares): share = shares[share_index % (nshares + 1)] if isinstance(share, torch.Tensor): party_shares.append(share) elif isinstance(share, ShareTensor): party_shares.append(share.tensor) else: raise TypeError(f"{type(share)} is an invalid share class") tensor = ReplicatedSharedTensor( session_uuid=session.rank_to_uuid[party_rank], config=config, ring_size=ring_size, ) tensor.shares = party_shares return tensor.send(party) @staticmethod def distribute_shares( shares: List[Union[ShareTensor, torch.Tensor]], session: Session, ring_size: Optional[int] = None, config: Optional[Config] = None, ) -> List["ReplicatedSharedTensor"]: """Distribute a list of shares. Args: shares (List[ShareTensor): list of shares to distribute. session (Session): Session. ring_size (int): ring_size the shares belong to. config (Config): The configuration(base,precision) of the tensor. 
Returns: List of ReplicatedSharedTensors. Raises: TypeError: when Datatype of shares is invalid. """ if not isinstance(shares, (list, tuple)): raise TypeError( "Shares to be distributed should be a list of shares") if len(shares) != session.nr_parties: return ValueError( "Number of shares to be distributed should be same as number of parties" ) if ring_size is None: ring_size = session.ring_size if config is None: config = session.config args = [[shares, party_rank, session, ring_size, config] for party_rank in range(session.nr_parties)] return [ ReplicatedSharedTensor.distribute_shares_to_party(*arg) for arg in args ] @staticmethod def hook_property(property_name: str) -> Any: """Hook a framework property (only getter). Ex: * if we call "shape" we want to call it on the underlying tensor and return the result * if we call "T" we want to call it on the underlying tensor but we want to wrap it in the same tensor type Args: property_name (str): property to hook Returns: A hooked property """ def property_new_rs_tensor_getter( _self: "ReplicatedSharedTensor") -> Any: shares = [] for share in _self.shares: tensor = getattr(share, property_name) shares.append(tensor) res = ReplicatedSharedTensor( session_uuid=_self.session_uuid, config=_self.config, ring_size=_self.ring_size, ) res.shares = shares return res def property_getter(_self: "ReplicatedSharedTensor") -> Any: prop = getattr(_self.shares[0], property_name) return prop if property_name in PROPERTIES_NEW_RS_TENSOR: res = property(property_new_rs_tensor_getter, None) else: res = property(property_getter, None) return res @staticmethod def hook_method(method_name: str) -> Callable[..., Any]: """Hook a framework method such that we know how to treat it given that we call it. 
Ex: * if we call "numel" we want to call it on the underlying tensor and return the result * if we call "unsqueeze" we want to call it on the underlying tensor but we want to wrap it in the same tensor type Args: method_name (str): method to hook Returns: A hooked method """ def method_new_rs_tensor(_self: "ReplicatedSharedTensor", *args: List[Any], **kwargs: Dict[Any, Any]) -> Any: shares = [] for share in _self.shares: tensor = getattr(share, method_name)(*args, **kwargs) shares.append(tensor) res = ReplicatedSharedTensor( session_uuid=_self.session_uuid, config=_self.config, ring_size=_self.ring_size, ) res.shares = shares return res def method(_self: "ReplicatedSharedTensor", *args: List[Any], **kwargs: Dict[Any, Any]) -> Any: method = getattr(_self.shares[0], method_name) res = method(*args, **kwargs) return res if method_name in METHODS_NEW_RS_TENSOR: res = method_new_rs_tensor else: res = method return res __add__ = add __radd__ = add __sub__ = sub __rsub__ = rsub __mul__ = mul __rmul__ = mul __matmul__ = matmul __rmatmul__ = rmatmul __truediv__ = truediv __floordiv__ = truediv __xor__ = xor __eq__ = eq __rshift__ = rshift
class ShareTensor:
    """This class represents 1 share that a party holds when doing secret sharing.

    Arguments:
        data (Optional[Any]): the share a party holds
        session (Optional[Any]): the session from which this shares belongs to
        encoder_base (int): the base for the encoder
        encoder_precision (int): the precision for the encoder
        ring_size (int): field used for the operations applied on the shares

    Attributes:
        Syft Serializable Attributes
        id (UID): the id to store the session
        tags (Optional[List[str]): an optional list of strings that are tags used at search
        description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session (Session): keep track from which session this share belongs to
        fp_encoder (FixedPointEncoder): the encoder used to convert a share from/to fixed point
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session",
        "fp_encoder",
    }

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        session: Optional["Session"] = None,
        encoder_base: int = 2,
        encoder_precision: int = 16,
        ring_size: int = 2**64,
    ) -> None:
        """Initializer for the ShareTensor.

        Args:
            data (Optional[Union[float, int, torch.Tensor]]): Value to encode
                and store as this share. Defaults to None (no tensor stored).
            session (Optional[Session]): Owning session. When None, a new
                Session with "ring_size" is created and the encoder settings
                are written into its config.
            encoder_base (int): Encoder base; ignored when "session" is given.
            encoder_precision (int): Encoder precision; ignored when "session" is given.
            ring_size (int): Ring size of the new Session; ignored when "session" is given.
        """
        if session is None:
            self.session = Session(ring_size=ring_size)
            self.session.config.encoder_precision = encoder_precision
            self.session.config.encoder_base = encoder_base
        else:
            # An explicit session wins: its config overrides the encoder arguments.
            self.session = session
            encoder_precision = self.session.config.encoder_precision
            encoder_base = self.session.config.encoder_base

        # TODO: It looks like the same logic as above
        self.fp_encoder = FixedPointEncoder(base=encoder_base, precision=encoder_precision)

        self.tensor: Optional[torch.Tensor] = None
        if data is not None:
            tensor_type = self.session.tensor_type
            self.tensor = self._encode(data).type(tensor_type)

    def _encode(self, data):
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode the stored tensor via the FixedPointEncoder."""
        return self._decode()

    def _decode(self):
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(
        x: "ShareTensor", y: Union[int, float, torch.Tensor, "ShareTensor"], op_str: str
    ) -> "ShareTensor":
        """Check the type of "y" and convert it to a share if necessary.

        Args:
            x (ShareTensor): Typically "self"; its session is reused for "y".
            y (Union[int, float, torch.Tensor, ShareTensor]): Value to check.
            op_str (str): Operation name (currently unused).

        Returns:
            ShareTensor: "y" itself, or "y" encoded as a ShareTensor.
        """
        if not isinstance(y, ShareTensor):
            y = ShareTensor(data=y, session=x.session)

        return y

    def apply_function(
        self, y: Union["ShareTensor", torch.Tensor, int, float], op_str: str
    ) -> "ShareTensor":
        """Apply the "operator" module function "op_str" to "self" and "y".

        Args:
            y (Union[ShareTensor, torch.Tensor, int, float]): Right operand.
            op_str (str): Name of a function in the "operator" module.

        Returns:
            ShareTensor: Result of the operation, in the same session.
        """
        op = getattr(operator, op_str)
        if isinstance(y, ShareTensor):
            value = op(self.tensor, y.tensor)
        else:
            value = op(self.tensor, y)

        res = ShareTensor(session=self.session)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        Returns:
            ShareTensor: self + y
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        res = self.apply_function(y_share, "add")
        return res

    def sub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        Returns:
            ShareTensor: self - y
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = self.apply_function(y_share, "sub")
        return res

    def rsub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        Returns:
            ShareTensor: y - self
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = y_share.apply_function(self, "sub")
        return res

    def mul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        Returns:
            ShareTensor: self * y
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def matmul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        Returns:
            ShareTensor: self @ y
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the reversed "matmul" operation between "self" and "y".

        Returns:
            ShareTensor: y @ self
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Floor-divide the stored tensor by a public integer "y".

        Returns:
            ShareTensor: self // y

        Raises:
            ValueError: If "y" is not an int or a torch.LongTensor.
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        res = ShareTensor(session=self.session)
        res.tensor = self.tensor // y
        return res

    def __getattr__(self, attr_name: str) -> Any:
        """Delegate unknown attributes (size, shape, ...) to the underlying "tensor".

        Returns:
            Any: The attribute value looked up on "self.tensor".

        Raises:
            AttributeError: If the attribute is missing, including when the
                "tensor" slot itself has not been assigned yet.
        """
        # __getattr__ also fires for an unassigned slot. Delegating "tensor"
        # would re-enter this method via self.tensor and recurse forever
        # (e.g. on an instance created through __new__ during unpickling).
        if attr_name == "tensor":
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{attr_name}'"
            )

        return getattr(self.tensor, attr_name)

    def __gt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> bool:
        """Check if "self" is bigger than "y".

        Returns:
            bool: self > y
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        res = self.tensor > y_share.tensor
        return res

    def __lt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> bool:
        """Check if "self" is less than "y".

        Returns:
            bool: self < y
        """
        y_share = ShareTensor.sanity_checks(self, y, "lt")
        res = self.tensor < y_share.tensor
        return res

    def __str__(self) -> str:
        """Return the string representation of ShareTensor."""
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Check if "self" is equal to "other" (same tensor values and session).

        Returns:
            bool: True if equal, False otherwise (including non-share operands).
        """
        if not isinstance(other, type(self)):
            return False

        if not (self.tensor == other.tensor).all():
            return False

        if not (self.session == other.session):
            return False

        return True

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
class ReplicatedSharedTensor(metaclass=SyMPCTensor):
    """RSTensor is used when a party holds more than a single share, required by various protocols.

    Arguments:
       shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list
           from which RSTensor is created.

    Attributes:
       shares: The shares held by the party
    """

    AUTOGRAD_IS_ON: bool = True

    # Used by the SyMPCTensor metaclass
    METHODS_FORWARD = {"numel", "t", "unsqueeze", "view", "sum", "clone"}
    PROPERTIES_FORWARD = {"T"}

    def __init__(
        self,
        shares: Optional[List[Union[float, int, torch.Tensor]]] = None,
        config: Optional[Config] = None,
        session_uuid: Optional[UUID] = None,
        ring_size: int = 2**64,
    ):
        """Initialize ReplicatedSharedTensor.

        Args:
            shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list
                from which RSTensor is created. Each one is fixed-point encoded
                and cast to the tensor type of "ring_size".
            config (Optional[Config]): The configuration where we keep the encoder
                precision and base. Defaults to a new Config(encoder_base=2,
                encoder_precision=16) per instance.
            session_uuid (Optional[UUID]): Used to keep track of a share that is
                associated with a remote session
            ring_size (int): field used for the operations applied on the shares
                Defaults to 2**64
        """
        # A fresh Config per call: a default instance in the signature would be
        # shared (and mutable) across every RSTensor created without a config.
        if config is None:
            config = Config(encoder_base=2, encoder_precision=16)

        self.session_uuid = session_uuid
        self.ring_size = ring_size
        self.config = config

        self.fp_encoder = FixedPointEncoder(
            base=config.encoder_base, precision=config.encoder_precision
        )

        tensor_type = get_type_from_ring(ring_size)
        self.shares = (
            []
            if shares is None
            else [self._encode(share).to(tensor_type) for share in shares]
        )

    def _encode(self, data):
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPointEncoder.

        Returns:
            List[torch.Tensor]: Decoded values
        """
        return self._decode()

    def _decode(self):
        """Decodes shares list of RSTensor via FixedPointEncoder.

        Returns:
            List[torch.Tensor]: Decoded values
        """
        return [
            self.fp_encoder.decode(share.type(torch.LongTensor)) for share in self.shares
        ]

    def add(self, y):
        """Apply the "add" operation between "self" and "y".

        Args:
            y: self+y
        """

    def sub(self, y):
        """Apply the "sub" operation between "self" and "y".

        Args:
            y: self-y
        """

    def rsub(self, y):
        """Apply the "sub" operation between "y" and "self".

        Args:
            y: y-self
        """

    def mul(self, y):
        """Apply the "mul" operation between "self" and "y".

        Args:
            y: self*y
        """

    def truediv(self, y):
        """Apply the "div" operation between "self" and "y".

        Args:
            y: self/y
        """

    def matmul(self, y):
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y: self@y
        """

    def rmatmul(self, y):
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y: y@self
        """

    def xor(self, y):
        """Apply the "xor" operation between "self" and "y".

        Args:
            y: self^y
        """

    def lt(self, y):
        """Lower than operator.

        Args:
            y: self<y
        """

    def gt(self, y):
        """Greater than operator.

        Args:
            y: self>y
        """

    def eq(self, y):
        """Equal operator.

        Args:
            y: self==y
        """

    def ne(self, y):
        """Not Equal operator.

        Args:
            y: self!=y
        """

    @staticmethod
    def hook_property(property_name: str) -> Any:
        """Hook a framework property (only getter).

        Ex:
         * if we call "shape" we want to call it on the underlying tensor
        and return the result
         * if we call "T" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            property_name (str): property to hook

        Returns:
            A hooked property
        """

        def property_new_rs_tensor_getter(_self: "ReplicatedSharedTensor") -> Any:
            shares = [getattr(share, property_name) for share in _self.shares]

            # Carry ring_size over; otherwise the wrapper falls back to 2**64.
            res = ReplicatedSharedTensor(
                session_uuid=_self.session_uuid,
                config=_self.config,
                ring_size=_self.ring_size,
            )
            res.shares = shares
            return res

        def property_getter(_self: "ReplicatedSharedTensor") -> Any:
            prop = getattr(_self.shares[0], property_name)
            return prop

        if property_name in PROPERTIES_NEW_RS_TENSOR:
            res = property(property_new_rs_tensor_getter, None)
        else:
            res = property(property_getter, None)

        return res

    @staticmethod
    def hook_method(method_name: str) -> Callable[..., Any]:
        """Hook a framework method such that we know how to treat it given that we call it.

        Ex:
         * if we call "numel" we want to call it on the underlying tensor
        and return the result
         * if we call "unsqueeze" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            method_name (str): method to hook

        Returns:
            A hooked method
        """

        def method_new_rs_tensor(
            _self: "ReplicatedSharedTensor", *args: List[Any], **kwargs: Dict[Any, Any]
        ) -> Any:
            shares = [
                getattr(share, method_name)(*args, **kwargs) for share in _self.shares
            ]

            # Carry ring_size over; otherwise the wrapper falls back to 2**64.
            res = ReplicatedSharedTensor(
                session_uuid=_self.session_uuid,
                config=_self.config,
                ring_size=_self.ring_size,
            )
            res.shares = shares
            return res

        def method(
            _self: "ReplicatedSharedTensor", *args: List[Any], **kwargs: Dict[Any, Any]
        ) -> Any:
            method = getattr(_self.shares[0], method_name)
            res = method(*args, **kwargs)
            return res

        if method_name in METHODS_NEW_RS_TENSOR:
            res = method_new_rs_tensor
        else:
            res = method

        return res

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = truediv
    __xor__ = xor
class ShareTensor(metaclass=SyMPCTensor):
    """Single Share representation.

    Arguments:
        data (Optional[Any]): the share a party holds
        session (Optional[Any]): the session to which this share belongs
        encoder_base (int): the base for the encoder
        encoder_precision (int): the precision for the encoder
        ring_size (int): size of the ring used for the operations applied on the shares

    Attributes:
        Syft Serializable Attributes

        id (UID): the id to store the session
        tags (Optional[List[str]): an optional list of strings that are tags used at search
        description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session (Session): keep track from which session this share belongs to
        fp_encoder (FixedPointEncoder): the encoder used to convert a share from/to fixed point
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session",
        "fp_encoder",
    }

    # Used by the SyMPCTensor metaclass
    METHODS_FORWARD: Set[str] = {
        "numel",
        "unsqueeze",
        "t",
        "view",
        "sum",
        "clone",
        "flatten",
        "reshape",
    }
    PROPERTIES_FORWARD: Set[str] = {"T", "shape"}

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        session: Optional[Session] = None,
        encoder_base: int = 2,
        encoder_precision: int = 16,
        ring_size: int = 2**64,
    ) -> None:
        """Initialize ShareTensor.

        Args:
            data (Optional[Any]): The share a party holds. Defaults to None.
            session (Optional[Any]): The session to which this share belongs.
                When given, its config overrides ``encoder_base`` and
                ``encoder_precision``, and ``ring_size`` is ignored.
                Defaults to None.
            encoder_base (int): The base for the encoder. Defaults to 2.
            encoder_precision (int): The precision for the encoder. Defaults to 16.
            ring_size (int): Size of the ring used for the operations applied
                on the shares. Only used when no session is given.
                Defaults to 2**64.
        """
        if session is None:
            # Standalone share: build a private session that carries the
            # requested ring size and encoder configuration.
            self.session = Session(ring_size=ring_size)
            self.session.config.encoder_precision = encoder_precision
            self.session.config.encoder_base = encoder_base
        else:
            # An explicit session wins over the encoder arguments.
            self.session = session
            encoder_precision = self.session.config.encoder_precision
            encoder_base = self.session.config.encoder_base

        # TODO: the encoder parameters duplicate what is stored in the session config.
        self.fp_encoder = FixedPointEncoder(base=encoder_base, precision=encoder_precision)

        self.tensor: Optional[torch.Tensor] = None
        if data is not None:
            # Encode to fixed point, then cast to the session's ring tensor type.
            tensor_type = self.session.tensor_type
            self.tensor = self._encode(data).type(tensor_type)

    def _encode(self, data):
        """Encode ``data`` to fixed point with this share's encoder."""
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPrecisionEncoder.

        Returns:
            torch.Tensor: Decoded value
        """
        return self._decode()

    def _decode(self):
        """Cast the stored tensor to LongTensor and decode it from fixed point."""
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(
        x: "ShareTensor", y: Union[int, float, torch.Tensor, "ShareTensor"], op_str: str
    ) -> "ShareTensor":
        """Check the type of "y" and convert it to a share if necessary.

        Args:
            x (ShareTensor): Typically "self"; its session is reused for "y".
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Tensor to check.
            op_str (str): String operator (currently unused).

        Returns:
            ShareTensor: the converted y value.
        """
        if not isinstance(y, ShareTensor):
            # Plain values are fixed-point encoded within x's session.
            y = ShareTensor(data=y, session=x.session)

        return y

    def apply_function(
        self, y: Union["ShareTensor", torch.Tensor, int, float], op_str: str
    ) -> "ShareTensor":
        """Apply a given operation.

        Args:
            y (Union["ShareTensor", torch.Tensor, int, float]): tensor to apply the operator.
            op_str (str): Name of a function in the ``operator`` module.

        Returns:
            ShareTensor: Result of the operation.
        """
        op = getattr(operator, op_str)

        if isinstance(y, ShareTensor):
            value = op(self.tensor, y.tensor)
        else:
            value = op(self.tensor, y)

        res = ShareTensor(session=self.session)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self + y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        res = self.apply_function(y_share, "add")
        return res

    def sub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self - y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = self.apply_function(y_share, "sub")
        return res

    def rsub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): y - self

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        # Operands are swapped on purpose: the result is y - self.
        res = y_share.apply_function(self, "sub")
        return res

    def mul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        When the session has no parties (a standalone share), the product is
        rescaled by ``fp_encoder.scale`` (floor division) so that it stays in
        the same fixed-point representation.

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self * y

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor.
            # In case we used the MPCTensor, the division would have
            # been done in the protocol.
            res.tensor //= self.fp_encoder.scale

        return res

    def xor(self, y: Union[int, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "xor" operation between "self" and "y".

        Unlike the arithmetic operations, plain values of "y" are NOT
        fixed-point encoded before the xor.

        Args:
            y (Union[int, torch.Tensor, "ShareTensor"]): self xor y

        Returns:
            ShareTensor: Result of the operation.
        """
        res = ShareTensor(session=self.session)

        if isinstance(y, ShareTensor):
            res.tensor = self.tensor ^ y.tensor
        else:
            res.tensor = self.tensor ^ y

        return res

    def matmul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        When the session has no parties (a standalone share), the product is
        rescaled by ``fp_encoder.scale`` (floor division).

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self @ y.

        Returns:
            ShareTensor: Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor.
            # In case we used the MPCTensor, the division would have
            # been done in the protocol.
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y (torch.Tensor): y @ self

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "div" operation between "self" and "y".

        The stored (already encoded) tensor is floor-divided by "y"; "y"
        itself is not fixed-point encoded.

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Denominator.

        Returns:
            ShareTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer or LongTensor.
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        res = ShareTensor(session=self.session)
        res.tensor = self.tensor // y
        return res

    def __gt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Greater than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare;
                plain values are fixed-point encoded first.

        Returns:
            torch.Tensor: Element-wise boolean result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        res = self.tensor > y_share.tensor
        return res

    def __lt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Lower than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare;
                plain values are fixed-point encoded first.

        Returns:
            torch.Tensor: Element-wise boolean result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "lt")
        res = self.tensor < y_share.tensor
        return res

    def __str__(self) -> str:
        """Representation.

        Returns:
            str: Return the string representation of ShareTensor.
        """
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self) -> str:
        """Representation.

        Returns:
            String representation.
        """
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Equal operator.

        Two shares are equal when their tensors have the same shape and
        elements (or are both None) and their sessions compare equal.
        Shapes are checked explicitly so that broadcasting cannot make
        differently-shaped shares compare equal.

        Args:
            other (Any): Tensor to compare.

        Returns:
            bool: True if equal False if not.
        """
        if not isinstance(other, ShareTensor):
            return False

        if self.tensor is None or other.tensor is None:
            if self.tensor is not other.tensor:
                return False
        elif self.tensor.shape != other.tensor.shape:
            return False
        elif not (self.tensor == other.tensor).all():
            return False

        if not (self.session == other.session):
            return False

        return True

    @staticmethod
    def hook_property(property_name: str) -> Any:
        """Hook a framework property (only getter).

        Ex:
         * if we call "shape" we want to call it on the underlying tensor
        and return the result
         * if we call "T" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            property_name (str): property to hook

        Returns:
            A hooked property
        """

        def property_new_share_tensor_getter(_self: "ShareTensor") -> Any:
            # Wrap the underlying result in a new share of the same session.
            tensor = getattr(_self.tensor, property_name)
            res = ShareTensor(session=_self.session)
            res.tensor = tensor
            return res

        def property_getter(_self: "ShareTensor") -> Any:
            # Return the underlying tensor's property unchanged.
            prop = getattr(_self.tensor, property_name)
            return prop

        if property_name in PROPERTIES_NEW_SHARE_TENSOR:
            res = property(property_new_share_tensor_getter, None)
        else:
            res = property(property_getter, None)

        return res

    @staticmethod
    def hook_method(method_name: str) -> Callable[..., Any]:
        """Hook a framework method such that we know how to treat it given that we call it.

        Ex:
         * if we call "numel" we want to call it on the underlying tensor
        and return the result
         * if we call "unsqueeze" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            method_name (str): method to hook

        Returns:
            A hooked method
        """

        def method_new_share_tensor(
            _self: "ShareTensor", *args: List[Any], **kwargs: Dict[Any, Any]
        ) -> Any:
            # Wrap the underlying result in a new share of the same session.
            tensor_method = getattr(_self.tensor, method_name)
            tensor = tensor_method(*args, **kwargs)
            res = ShareTensor(session=_self.session)
            res.tensor = tensor
            return res

        def method(_self: "ShareTensor", *args: List[Any], **kwargs: Dict[Any, Any]) -> Any:
            # Return the underlying tensor method's result unchanged.
            tensor_method = getattr(_self.tensor, method_name)
            res = tensor_method(*args, **kwargs)
            return res

        if method_name in METHODS_NEW_SHARE_TENSOR:
            res = method_new_share_tensor
        else:
            res = method

        return res

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
    __xor__ = xor