def __init__(
    self,
    data: Optional[Union[float, int, torch.Tensor]] = None,
    session: Optional[Session] = None,
    encoder_base: int = 2,
    encoder_precision: int = 16,
    ring_size: int = 2**64,
) -> None:
    """Initializer for the ShareTensor.

    When a session is supplied, its config decides the encoder base and
    precision and the ``encoder_base``/``encoder_precision``/``ring_size``
    arguments are not used. Otherwise a new ``Session(ring_size=ring_size)``
    is created and its config is overwritten with the given encoder settings.

    Args:
        data: Value to encode and store in ``self.tensor``. When None,
            ``self.tensor`` stays None.
        session: Session this share belongs to. Defaults to None.
        encoder_base: Encoder base used when no session is given.
        encoder_precision: Encoder precision used when no session is given.
        ring_size: Ring size of the session created when none is given.
    """
    if session is not None:
        self.session = session
        # An explicit session overrides the keyword encoder settings.
        encoder_precision = self.session.config.encoder_precision
        encoder_base = self.session.config.encoder_base
    else:
        self.session = Session(ring_size=ring_size)
        self.session.config.encoder_precision = encoder_precision
        self.session.config.encoder_base = encoder_base

    # TODO(review): the encoder settings are kept both in self.session.config
    # and in self.fp_encoder (the original note: "looks like the same logic").
    self.fp_encoder = FixedPointEncoder(base=encoder_base, precision=encoder_precision)

    self.tensor: Optional[torch.Tensor] = None
    if data is None:
        return

    target_type = self.session.tensor_type
    self.tensor = self._encode(data).type(target_type)
def __init__(
    self,
    data: Optional[Union[float, int, torch.Tensor]] = None,
    config: Config = Config(encoder_base=2, encoder_precision=16),
    session_uuid: Optional[UUID] = None,
    ring_size: int = 2**64,
) -> None:
    """Initialize ShareTensor.

    The encoder is built from ``config``. If ``data`` is given, it is
    fixed-point encoded and converted to the tensor type associated with
    ``ring_size``.

    Args:
        data (Optional[Any]): The share a party holds. Defaults to None
        config (Config): The configuration where we keep the encoder precision and base.
        session_uuid (Optional[UUID]): Used to keep track of a share that is associated with a remote session
        ring_size (int): field used for the operations applied on the shares
            Defaults to 2**64
    """
    # NOTE(review): the default Config is evaluated once, when the function is
    # defined. Every call that omits ``config`` shares that single instance,
    # so it must never be mutated in place.
    self.session_uuid, self.ring_size, self.config = session_uuid, ring_size, config

    base, precision = config.encoder_base, config.encoder_precision
    self.fp_encoder = FixedPointEncoder(base=base, precision=precision)

    self.tensor: Optional[torch.Tensor] = None
    if data is None:
        return

    ring_dtype = get_type_from_ring(ring_size)
    self.tensor = self._encode(data).to(ring_dtype)
def __init__(
    self,
    shares: Optional[List[Union[float, int, torch.Tensor]]] = None,
    config: Config = Config(encoder_base=2, encoder_precision=16),
    session_uuid: Optional[UUID] = None,
    ring_size: int = 2**64,
):
    """Initialize ReplicatedSharedTensor.

    Each entry of ``shares`` is fixed-point encoded with the encoder built
    from ``config`` and converted to the tensor type for ``ring_size``.

    Args:
        shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list from which RSTensor is created.
        config (Config): The configuration where we keep the encoder precision and base.
        session_uuid (Optional[UUID]): Used to keep track of a share that is associated with a remote session
        ring_size (int): field used for the operations applied on the shares
            Defaults to 2**64
    """
    self.session_uuid = session_uuid
    self.ring_size = ring_size
    self.config = config
    self.fp_encoder = FixedPointEncoder(
        base=config.encoder_base, precision=config.encoder_precision
    )

    # The ring type is resolved even when no shares are provided.
    ring_dtype = get_type_from_ring(ring_size)
    self.shares = (
        []
        if shares is None
        else [self._encode(value).to(ring_dtype) for value in shares]
    )
def reconstruct(
    self, decode: bool = True, get_shares: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Reconstruct the secret.

    The reconstruction itself is delegated to the protocol's share class,
    which requests and gets the shares from all the parties. Depending on the
    value of "decode", the secret is then decoded or not using the
    FixedPointEncoder configured by the session.

    Args:
        decode (bool): True if decode using FixedPointEncoder. Defaults to True
        get_shares (bool): Retrieve only shares. Defaults to False.

    Returns:
        torch.Tensor or List[torch.Tensor]. The secret reconstructed, or
        whatever the share class returns for ``get_shares=True`` (returned
        without decoding).
    """
    share_class = self.session.protocol.share_class
    shares_or_secret = share_class.reconstruct(
        self.share_ptrs,
        get_shares=get_shares,
        security_type=self.session.protocol.security_type,
    )

    # Raw shares and undecoded secrets are handed back untouched.
    if get_shares or not decode:
        return shares_or_secret

    encoder = FixedPointEncoder(
        base=self.session.config.encoder_base,
        precision=self.session.config.encoder_precision,
    )
    return encoder.decode(shares_or_secret)
def test_fp_precision_setter():
    """Assigning ``precision`` must update both the precision and the scale."""
    new_precision = 3
    encoder = FixedPointEncoder()
    encoder.precision = new_precision

    assert encoder.precision == new_precision
    assert encoder.scale == encoder.base**new_precision
def test_fp_base_setter():
    """Assigning ``base`` must update both the base and the scale."""
    new_base = 3
    encoder = FixedPointEncoder()
    encoder.base = new_base

    assert encoder.base == new_base
    assert encoder.scale == new_base**encoder.precision
def reconstruct(
    self, decode: bool = True, get_shares: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Request and get the shares from all the parties and reconstruct the secret.

    Depending on the value of "decode", the secret would be decoded or not using
    the FixedPrecision Encoder specific for the session.

    Args:
        decode (bool): True if decode using FixedPointEncoder. Defaults to True
        get_shares (bool): True if get shares. Defaults to False.

    Returns:
        torch.Tensor or List[torch.Tensor]. The secret reconstructed (the sum
        of the share tensors, decoded if requested), or the list of share
        tensors when ``get_shares`` is True.
    """

    def _request_and_get(share_ptr: ShareTensor) -> ShareTensor:
        """Function used to request and get a share - Duet Setup.

        Args:
            share_ptr (ShareTensor): a ShareTensor

        Returns:
            ShareTensor. The ShareTensor in local.
        """
        # Remote pointers need an explicit (blocking) access request first.
        if not islocal(share_ptr):
            share_ptr.request(name="reconstruct", block=True)
        res = share_ptr.get_copy()
        return res

    request_wrap = parallel_execution(_request_and_get)

    # parallel_execution expects one argument list per call.
    args = [[share] for share in self.share_ptrs]
    local_shares = request_wrap(args)

    shares = [share.tensor for share in local_shares]
    if get_shares:
        return shares

    plaintext = sum(shares)

    if decode:
        fp_encoder = FixedPointEncoder(
            base=self.session.config.encoder_base,
            precision=self.session.config.encoder_precision,
        )
        plaintext = fp_encoder.decode(plaintext)

    return plaintext
def test_fixed_point(precision, base) -> None:
    """RSTensor shares must match manual fixed-point encoding and decoding."""
    secret = torch.tensor([1.25, 3.301])
    raw_shares = [secret, secret]
    config = Config(encoder_precision=precision, encoder_base=base)
    rst = ReplicatedSharedTensor(shares=raw_shares, config=config)

    encoder = FixedPointEncoder(precision=precision, base=base)
    ring_type = get_type_from_ring(rst.ring_size)

    encoded = [encoder.encode(value).to(ring_type) for value in raw_shares]
    assert (torch.cat(encoded) == torch.cat(rst.shares)).all()

    decoded = [encoder.decode(value.type(torch.LongTensor)) for value in encoded]
    assert (torch.cat(decoded) == torch.cat(rst.decode())).all()
def lt(self, other: "MPCTensor") -> "MPCTensor":
    """Lower than operator.

    Computed as ``(self + epsilon) <= other``, where ``epsilon`` is the
    fixed-point decoding of the integer 1 under the session's encoder.

    Args:
        other (MPCTensor): MPCTensor to compare.

    Returns:
        MPCTensor: Result of the comparison.
    """
    protocol = self.session.get_protocol()
    other = self.__check_or_convert(other, self.session)

    config = self.session.config
    encoder = FixedPointEncoder(
        base=config.encoder_base,
        precision=config.encoder_precision,
    )
    epsilon = encoder.decode(1)

    return protocol.le(self + epsilon, other)
def test_fp_encoder_init():
    """Test correct FixedPointEncoder initialisation."""
    base, precision = 3, 8
    encoder = FixedPointEncoder(base=base, precision=precision)

    assert encoder.base == base
    assert encoder.precision == precision
    assert encoder.scale == base**precision
def test_fp_encoding():
    """Test correct encoding with FixedPointEncoder.

    Tensors, floats and ints must all be multiplied by the encoder scale.
    """
    encoder = FixedPointEncoder()

    # (input value, expected value before scaling)
    cases = [
        (torch.Tensor([1, 2, 3]), torch.LongTensor([1, 2, 3])),  # tensor
        (42.0, torch.LongTensor([42])),  # float
        (2, torch.LongTensor([2])),  # int
    ]

    for value, unscaled in cases:
        assert (encoder.encode(value) == unscaled * encoder.scale).all()
def test_truncation_algorithm1(get_clients, base, precision) -> None:
    """ABY3 truncation must match floor-dividing the raw secret by the scale."""
    parties = get_clients(3)
    protocol = Falcon("semi-honest")
    config = Config(encoder_base=base, encoder_precision=precision)
    session = Session(parties=parties, protocol=protocol, config=config)
    SessionManager.setup_mpc(session)

    secret = torch.tensor([[1.24, 4.51, 6.87], [7.87, 1301, 541]])
    secret_mpc = MPCTensor(secret=secret, session=session)

    truncated = ABY3.truncate(secret_mpc, session, session.ring_size, session.config)

    encoder = FixedPointEncoder(
        base=session.config.encoder_base,
        precision=session.config.encoder_precision,
    )
    raw = secret_mpc.reconstruct(decode=False)
    expected = encoder.decode(raw // encoder.scale)

    assert np.allclose(truncated.reconstruct(), expected, atol=1e-3)
def test_fp_decoding():
    """Test correct decoding with FixedPointEncoder."""
    # Floating point tensors are rejected with a ValueError.
    float_tensor = torch.Tensor([1.00, 2.00, 3.00])
    with pytest.raises(ValueError):
        FixedPointEncoder().decode(float_tensor)

    expected = torch.LongTensor([2])

    # Integer input with precision 0: nothing to undo.
    unscaled_encoder = FixedPointEncoder(precision=0)
    assert (unscaled_encoder.decode(2) == expected).all()

    # Integer input with the default precision: the scaling is undone.
    default_encoder = FixedPointEncoder()
    scaled_int = 2 * default_encoder.base**default_encoder.precision
    assert (default_encoder.decode(scaled_int) == expected).all()
def reconstruct(self, decode: bool = True) -> torch.Tensor:
    """Request and get the shares from all the parties and reconstruct the secret.

    Depending on the value of "decode", the secret would be decoded or not using
    the FixedPrecision Encoder specific for the session.

    Args:
        decode (bool): True if decode using FixedPointEncoder. Defaults to True.

    Returns:
        torch.Tensor. The secret reconstructed (sum of the share tensors,
        decoded if requested).
    """

    def _request_and_get(share_ptr: ShareTensor) -> ShareTensor:
        """Function used to request and get a share - Duet Setup.

        Args:
            share_ptr (ShareTensor): a ShareTensor

        Returns:
            ShareTensor. The ShareTensor in local.
        """
        # Remote pointers need an explicit (blocking) access request first.
        if not islocal(share_ptr):
            share_ptr.request(name="reconstruct", block=True)
        res = share_ptr.get_copy()
        return res

    request_wrap = parallel_execution(_request_and_get)

    # parallel_execution expects one argument list per call.
    args = [[share] for share in self.share_ptrs]
    local_shares = request_wrap(args)

    plaintext = sum(share.tensor for share in local_shares)

    if decode:
        fp_encoder = FixedPointEncoder(
            base=self.session.config.encoder_base,
            precision=self.session.config.encoder_precision,
        )
        plaintext = fp_encoder.decode(plaintext)

    return plaintext
def test_private_compare(get_clients, security) -> None:
    """Falcon.private_compare must flag the entries where the secret exceeds r."""
    parties = get_clients(3)
    protocol = Falcon(security_type=security)
    session = Session(parties=parties, protocol=protocol)
    SessionManager.setup_mpc(session)

    encoder = FixedPointEncoder(
        base=session.config.encoder_base,
        precision=session.config.encoder_precision,
    )

    secret = torch.tensor([[358.85, 79.29], [67.78, 2415.50]])
    r = encoder.encode(torch.tensor([[357.05, 90], [145.32, 2400.54]]))

    x = MPCTensor(secret=secret, session=session)
    bit_shares = ABY3.bit_decomposition_ttp(x, session)
    prime_shares = [
        ABY3.bit_injection(bit_share, session, PRIME_NUMBER) for bit_share in bit_shares
    ]

    ring_type = get_type_from_ring(session.ring_size)
    result = Falcon.private_compare(prime_shares, r.type(ring_type))

    expected = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool)
    assert (result.reconstruct(decode=False) == expected).all()
def __init__(
    self,
    data: Optional[Union[float, int, torch.Tensor]] = None,
    session: Optional[Session] = None,
    encoder_base: int = 2,
    encoder_precision: int = 16,
    ring_size: int = 2**64,
) -> None:
    """Initialize ShareTensor.

    When ``session`` is given, the encoder settings are read from its config
    and the keyword encoder arguments (and ``ring_size``) are not used.

    Args:
        data (Optional[Any]): The share a party holds. Defaults to None
        session (Optional[Any]): The session from which this shares belongs to.
            Defaults to None.
        encoder_base (int): The base for the encoder. Defaults to 2.
        encoder_precision (int): the precision for the encoder. Defaults to 16.
        ring_size (int): field used for the operations applied on the shares
            Defaults to 2**64
    """
    if session is None:
        new_session = Session(ring_size=ring_size)
        new_session.config.encoder_precision = encoder_precision
        new_session.config.encoder_base = encoder_base
        self.session = new_session
    else:
        self.session = session
        encoder_precision = self.session.config.encoder_precision
        encoder_base = self.session.config.encoder_base

    # TODO(review): encoder settings are duplicated between self.session.config
    # and self.fp_encoder (original note: "looks like the same logic as above").
    self.fp_encoder = FixedPointEncoder(base=encoder_base, precision=encoder_precision)

    self.tensor: Optional[torch.Tensor] = None
    if data is not None:
        target_type = self.session.tensor_type
        self.tensor = self._encode(data).type(target_type)
def test_fp_string_representation():
    """The default encoder must render its precision and base."""
    expected = "[FixedPointEncoder]: precision: 16, base: 2"
    assert str(FixedPointEncoder()) == expected
class ReplicatedSharedTensor(metaclass=SyMPCTensor):
    """RSTensor is used when a party holds more than a single share, as required by various protocols.

    Arguments:
        shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list
            from which RSTensor is created.
        config (Config): The configuration where we keep the encoder precision
            and base.
        session_uuid (Optional[UUID]): Used to keep track of a share that is
            associated with a remote session.
        ring_size (int): field used for the operations applied on the shares.

    Attributes:
        shares: The shares held by the party
        session_uuid: The uuid of the session the shares belong to (or None)
        ring_size: field used for the operations applied on the shares
        config: The encoder configuration
        fp_encoder: The FixedPointEncoder built from ``config``
    """

    AUTOGRAD_IS_ON: bool = True

    # Used by the SyMPCTensor metaclass
    METHODS_FORWARD = {"numel", "t", "unsqueeze", "view", "sum", "clone"}
    PROPERTIES_FORWARD = {"T"}

    def __init__(
        self,
        shares: Optional[List[Union[float, int, torch.Tensor]]] = None,
        config: Config = Config(encoder_base=2, encoder_precision=16),
        session_uuid: Optional[UUID] = None,
        ring_size: int = 2**64,
    ):
        """Initialize ReplicatedSharedTensor.

        Args:
            shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list
                from which RSTensor is created.
            config (Config): The configuration where we keep the encoder precision
                and base.
            session_uuid (Optional[UUID]): Used to keep track of a share that is
                associated with a remote session
            ring_size (int): field used for the operations applied on the shares
                Defaults to 2**64
        """
        self.session_uuid = session_uuid
        self.ring_size = ring_size
        self.config = config
        self.fp_encoder = FixedPointEncoder(
            base=config.encoder_base, precision=config.encoder_precision
        )

        tensor_type = get_type_from_ring(ring_size)
        self.shares = []
        if shares is not None:
            self.shares = [self._encode(share).to(tensor_type) for share in shares]

    def _encode(self, data):
        """Fixed-point encode ``data`` with this tensor's encoder."""
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPointEncoder.

        Returns:
            List[torch.Tensor]: Decoded values
        """
        return self._decode()

    def _decode(self):
        """Decodes shares list of RSTensor via FixedPointEncoder.

        Returns:
            List[torch.Tensor]: Decoded values
        """
        return [
            self.fp_encoder.decode(share.type(torch.LongTensor))
            for share in self.shares
        ]

    # NOTE(review): the arithmetic and comparison methods below are placeholders.
    # They only carry a docstring and therefore return None.

    def add(self, y):
        """Apply the "add" operation between "self" and "y".

        Args:
            y: self+y
        """

    def sub(self, y):
        """Apply the "sub" operation between "self" and "y".

        Args:
            y: self-y
        """

    def rsub(self, y):
        """Apply the "sub" operation between "y" and "self".

        Args:
            y: y-self
        """

    def mul(self, y):
        """Apply the "mul" operation between "self" and "y".

        Args:
            y: self*y
        """

    def truediv(self, y):
        """Apply the "truediv" operation between "self" and "y".

        Args:
            y: self/y
        """

    def matmul(self, y):
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y: self@y
        """

    def rmatmul(self, y):
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y: y@self
        """

    def xor(self, y):
        """Apply the "xor" operation between "self" and "y".

        Args:
            y: self^y
        """

    def lt(self, y):
        """Lower than operator.

        Args:
            y: self<y
        """

    def gt(self, y):
        """Greater than operator.

        Args:
            y: self>y
        """

    def eq(self, y):
        """Equal operator.

        Args:
            y: self==y
        """

    def ne(self, y):
        """Not Equal operator.

        Args:
            y: self!=y
        """

    @staticmethod
    def hook_property(property_name: str) -> Any:
        """Hook a framework property (only getter).

        Ex:
         * if we call "shape" we want to call it on the underlying tensor
        and return the result
         * if we call "T" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            property_name (str): property to hook

        Returns:
            A hooked property
        """

        def property_new_rs_tensor_getter(_self: "ReplicatedSharedTensor") -> Any:
            shares = [getattr(share, property_name) for share in _self.shares]
            # Carry ring_size over, otherwise the wrapper would silently fall
            # back to the constructor default (2**64).
            res = ReplicatedSharedTensor(
                session_uuid=_self.session_uuid,
                config=_self.config,
                ring_size=_self.ring_size,
            )
            res.shares = shares
            return res

        def property_getter(_self: "ReplicatedSharedTensor") -> Any:
            # Non-wrapping properties are read from the first share only.
            prop = getattr(_self.shares[0], property_name)
            return prop

        if property_name in PROPERTIES_NEW_RS_TENSOR:
            res = property(property_new_rs_tensor_getter, None)
        else:
            res = property(property_getter, None)

        return res

    @staticmethod
    def hook_method(method_name: str) -> Callable[..., Any]:
        """Hook a framework method such that we know how to treat it given that we call it.

        Ex:
         * if we call "numel" we want to call it on the underlying tensor
        and return the result
         * if we call "unsqueeze" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            method_name (str): method to hook

        Returns:
            A hooked method
        """

        def method_new_rs_tensor(
            _self: "ReplicatedSharedTensor", *args: List[Any], **kwargs: Dict[Any, Any]
        ) -> Any:
            shares = [
                getattr(share, method_name)(*args, **kwargs) for share in _self.shares
            ]
            # Carry ring_size over, otherwise the wrapper would silently fall
            # back to the constructor default (2**64).
            res = ReplicatedSharedTensor(
                session_uuid=_self.session_uuid,
                config=_self.config,
                ring_size=_self.ring_size,
            )
            res.shares = shares
            return res

        def method(
            _self: "ReplicatedSharedTensor", *args: List[Any], **kwargs: Dict[Any, Any]
        ) -> Any:
            # Non-wrapping methods are applied to the first share only.
            bound = getattr(_self.shares[0], method_name)
            res = bound(*args, **kwargs)
            return res

        if method_name in METHODS_NEW_RS_TENSOR:
            res = method_new_rs_tensor
        else:
            res = method

        return res

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = truediv
    __xor__ = xor
class ShareTensor(metaclass=SyMPCTensor):
    """Single Share representation.

    Arguments:
        data (Optional[Any]): the share a party holds
        config (Config): the configuration where we keep the encoder precision and base
        session_uuid (Optional[UUID]): the uuid of the session this share belongs to
        ring_size (int): field used for the operations applied on the shares

    Attributes:
        Syft Serializable Attributes

        id (UID): the id to store the session
        tags (Optional[List[str]): an optional list of strings that are tags used at search
        description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session_uuid (Optional[UUID]): keep track from which session the share belongs to
        config (Config): configuration holding the encoder precision and base
        fp_encoder (FixedPointEncoder): the encoder built from ``config``
        ring_size (int): field used for the operations applied on the shares
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session_uuid",
        "config",
        "fp_encoder",
        "ring_size",
    }

    # Used by the SyMPCTensor metaclass
    METHODS_FORWARD: Set[str] = {
        "numel",
        "squeeze",
        "unsqueeze",
        "t",
        "view",
        "expand",
        "sum",
        "clone",
        "flatten",
        "reshape",
        "repeat",
        "narrow",
        "dim",
        "transpose",
        "roll",
    }
    PROPERTIES_FORWARD: Set[str] = {"T", "shape"}

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        config: Config = Config(encoder_base=2, encoder_precision=16),
        session_uuid: Optional[UUID] = None,
        ring_size: int = 2**64,
    ) -> None:
        """Initialize ShareTensor.

        Args:
            data (Optional[Any]): The share a party holds. Defaults to None
            config (Config): The configuration where we keep the encoder precision and base.
            session_uuid (Optional[UUID]): Used to keep track of a share that is associated with a remote session
            ring_size (int): field used for the operations applied on the shares
                Defaults to 2**64
        """
        # NOTE(review): the default Config is created once at definition time and
        # shared by every call that omits ``config``; never mutate it in place.
        self.session_uuid = session_uuid
        self.ring_size = ring_size
        self.config = config

        self.fp_encoder = FixedPointEncoder(
            base=config.encoder_base, precision=config.encoder_precision
        )

        self.tensor: Optional[torch.Tensor] = None
        if data is not None:
            tensor_type = get_type_from_ring(ring_size)
            self.tensor = self._encode(data).to(tensor_type)

    def _encode(self, data):
        """Fixed-point encode ``data`` with this share's encoder."""
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPrecisionEncoder.

        Returns:
            torch.Tensor: Decoded value
        """
        return self._decode()

    def _decode(self):
        """Decode ``self.tensor`` (cast to LongTensor) with this share's encoder."""
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(
        x: "ShareTensor", y: Union[int, float, torch.Tensor, "ShareTensor"], op_str: str
    ) -> "ShareTensor":
        """Check the type of "y" and convert it to share if necessary.

        Args:
            x (ShareTensor): Typically "self".
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Tensor to check.
            op_str (str): String operator.

        Returns:
            ShareTensor: the converted y value.

        Raises:
            ValueError: if both values are shares and they have different uuids
        """
        if not isinstance(y, ShareTensor):
            # Public values are encoded with the session settings when x is
            # attached to a session, otherwise with x's own settings.
            if x.session_uuid is not None:
                session = sympc.session.get_session(str(x.session_uuid))
                ring_size = session.ring_size
                config = session.config
            else:
                ring_size = x.ring_size
                config = x.config

            y = ShareTensor(data=y, ring_size=ring_size, config=config)

        elif y.session_uuid and x.session_uuid and y.session_uuid != x.session_uuid:
            raise ValueError(
                f"Session UUIDs did not match {x.session_uuid} {y.session_uuid}"
            )

        return y

    def apply_function(
        self, y: Union["ShareTensor", torch.Tensor, int, float], op_str: str
    ) -> "ShareTensor":
        """Apply a given operation.

        Args:
            y (Union["ShareTensor", torch.Tensor, int, float]): tensor to apply the operator.
            op_str (str): Operator.

        Returns:
            ShareTensor: Result of the operation.
        """
        op = getattr(operator, op_str)

        # Only a ShareTensor operand carries a tensor and a session uuid; raw
        # values (e.g. the int/tensor passed by ``xor``) are used as they are.
        if isinstance(y, ShareTensor):
            other_value = y.tensor
            other_uuid = y.session_uuid
        else:
            other_value = y
            other_uuid = None

        value = op(self.tensor, other_value)

        session_uuid = self.session_uuid or other_uuid

        if session_uuid is not None:
            session = sympc.session.get_session(str(session_uuid))
            ring_size = session.ring_size
            config = session.config
        else:
            # Use the values from "self"
            ring_size = self.ring_size
            config = self.config

        res = ShareTensor(ring_size=ring_size, session_uuid=session_uuid, config=config)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self + y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        res = self.apply_function(y_share, "add")
        return res

    def sub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self - y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = self.apply_function(y_share, "sub")
        return res

    def rsub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): y - self

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = y_share.apply_function(self, "sub")
        return res

    def mul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self * y

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session_uuid is None and y.session_uuid is None:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor //= self.fp_encoder.scale

        return res

    def xor(self, y: Union[int, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "xor" operation between "self" and "y".

        Unlike the arithmetic operators, a non-share "y" is used as it is
        (it is not fixed-point encoded).

        Args:
            y (Union[int, torch.Tensor, "ShareTensor"]): self xor y

        Returns:
            ShareTensor: Result of the operation.
        """
        res = self.apply_function(y, "xor")
        return res

    def matmul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self @ y.

        Returns:
            ShareTensor: Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session_uuid is None and y.session_uuid is None:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y (torch.Tensor): y @ self

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "div" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Denominator.

        Returns:
            ShareTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer or LongTensor.
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        # Keep the ring size of "self" instead of silently resetting it to
        # the constructor default.
        res = ShareTensor(
            session_uuid=self.session_uuid,
            config=self.config,
            ring_size=self.ring_size,
        )
        res.tensor = self.tensor // y

        return res

    def __gt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Greater than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Element-wise result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        res = self.tensor > y_share.tensor
        return res

    def __lt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Lower than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Element-wise result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "lt")
        res = self.tensor < y_share.tensor
        return res

    def __str__(self) -> str:
        """Representation.

        Returns:
            str: Return the string representation of ShareTensor.
        """
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| Session UUID: {self.session_uuid}"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self) -> str:
        """Representation.

        Returns:
            String representation.
        """
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Equal operator.

        Check if "self" is equal with another object given a set of
        attributes to compare.

        Args:
            other (Any): Tensor to compare.

        Returns:
            bool: True if equal False if not.
        """
        # NOTE(review): the element-wise comparison broadcasts, so tensors of
        # different but broadcast-compatible shapes may compare equal.
        if not (self.tensor == other.tensor).all():
            return False

        if not self.config == other.config:
            return False

        if (
            self.session_uuid
            and other.session_uuid
            and self.session_uuid != other.session_uuid
        ):
            # If both shares have a session_uuid we consider them not equal
            # else they are
            return False

        return True

    @staticmethod
    def hook_property(property_name: str) -> Any:
        """Hook a framework property (only getter).

        Ex:
         * if we call "shape" we want to call it on the underlying tensor
        and return the result
         * if we call "T" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            property_name (str): property to hook

        Returns:
            A hooked property
        """

        def property_new_share_tensor_getter(_self: "ShareTensor") -> Any:
            tensor = getattr(_self.tensor, property_name)
            # Carry ring_size over, otherwise the wrapper would fall back to
            # the constructor default (2**64).
            res = ShareTensor(
                session_uuid=_self.session_uuid,
                config=_self.config,
                ring_size=_self.ring_size,
            )
            res.tensor = tensor
            return res

        def property_getter(_self: "ShareTensor") -> Any:
            prop = getattr(_self.tensor, property_name)
            return prop

        if property_name in PROPERTIES_NEW_SHARE_TENSOR:
            res = property(property_new_share_tensor_getter, None)
        else:
            res = property(property_getter, None)

        return res

    @staticmethod
    def hook_method(method_name: str) -> Callable[..., Any]:
        """Hook a framework method such that we know how to treat it given that we call it.

        Ex:
         * if we call "numel" we want to call it on the underlying tensor
        and return the result
         * if we call "unsqueeze" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            method_name (str): method to hook

        Returns:
            A hooked method
        """

        def method_new_share_tensor(
            _self: "ShareTensor", *args: List[Any], **kwargs: Dict[Any, Any]
        ) -> Any:
            bound = getattr(_self.tensor, method_name)
            tensor = bound(*args, **kwargs)
            # Carry ring_size over, otherwise the wrapper would fall back to
            # the constructor default (2**64).
            res = ShareTensor(
                session_uuid=_self.session_uuid,
                config=_self.config,
                ring_size=_self.ring_size,
            )
            res.tensor = tensor
            return res

        def method(
            _self: "ShareTensor", *args: List[Any], **kwargs: Dict[Any, Any]
        ) -> Any:
            bound = getattr(_self.tensor, method_name)
            res = bound(*args, **kwargs)
            return res

        if method_name in METHODS_NEW_SHARE_TENSOR:
            res = method_new_share_tensor
        else:
            res = method

        return res

    @staticmethod
    def reconstruct(
        share_ptrs: List["ShareTensor"],
        get_shares=False,
        security_type: str = "semi-honest",
    ) -> torch.Tensor:
        """Reconstruct original value from shares.

        Args:
            share_ptrs (List[ShareTensor]): List of sharetensors.
            get_shares (boolean): retrieve shares or reconstructed value.
            security_type (str): Type of security by protocol (not used by
                this implementation).

        Returns:
            plaintext/shares (torch.Tensor/List[torch.Tensors]): Plaintext or list of shares.
        """

        def _request_and_get(share_ptr: ShareTensor) -> ShareTensor:
            """Function used to request and get a share - Duet Setup.

            Args:
                share_ptr (ShareTensor): a ShareTensor

            Returns:
                ShareTensor. The ShareTensor in local.
            """
            if not islocal(share_ptr):
                share_ptr.request(block=True)
            res = share_ptr.get_copy()
            return res

        request_wrap = parallel_execution(_request_and_get)

        # parallel_execution expects one argument list per call.
        args = [[share] for share in share_ptrs]
        local_shares = request_wrap(args)

        shares = [share.tensor for share in local_shares]

        if get_shares:
            return shares

        plaintext = sum(shares)

        return plaintext

    @staticmethod
    def distribute_shares(shares: List["ShareTensor"], session: Session):
        """Distribute a list of shares.

        Each share is tagged with the session uuid of its rank before being
        sent to the party with that rank.

        Args:
            shares (List[ShareTensor): list of shares to distribute.
            session (Session): Session for which those shares were generated

        Returns:
            List of ShareTensorPointers.
        """
        rank_to_uuid = session.rank_to_uuid
        parties = session.parties

        share_ptrs = []
        for rank, share in enumerate(shares):
            share.session_uuid = rank_to_uuid[rank]
            party = parties[rank]
            share_ptrs.append(share.send(party))

        return share_ptrs

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
    __xor__ = xor
class ShareTensor:
    """This class represents 1 share that a party holds when doing secret sharing.

    Arguments:
        data (Optional[Any]): the share a party holds
        session (Optional[Any]): the session this share belongs to
        encoder_base (int): the base for the encoder
        encoder_precision (int): the precision for the encoder
        ring_size (int): field used for the operations applied on the shares

    Attributes:
        Syft Serializable Attributes

        id (UID): the id to store the session
        tags (Optional[List[str]]): an optional list of strings that are tags used at search
        description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session (Session): keep track from which session this share belongs to
        fp_encoder (FixedPointEncoder): the encoder used to convert a share from/to fixed point
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session",
        "fp_encoder",
    }

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        session: Optional[Session] = None,
        encoder_base: int = 2,
        encoder_precision: int = 16,
        ring_size: int = 2**64,
    ) -> None:
        """Initializer for the ShareTensor.

        If ``session`` is None, a new Session is created for ``ring_size`` and
        ``encoder_base``/``encoder_precision`` are written into its config.
        Otherwise the encoder settings are read from the given session's config,
        and the ``encoder_base``, ``encoder_precision`` and ``ring_size``
        arguments are ignored.
        """
        if session is None:
            self.session = Session(ring_size=ring_size, )
            self.session.config.encoder_precision = encoder_precision
            self.session.config.encoder_base = encoder_base
        else:
            self.session = session
            encoder_precision = self.session.config.encoder_precision
            encoder_base = self.session.config.encoder_base

        # TODO: It looks like the same logic as above
        self.fp_encoder = FixedPointEncoder(base=encoder_base, precision=encoder_precision)

        self.tensor: Optional[torch.Tensor] = None
        if data is not None:
            tensor_type = self.session.tensor_type
            # Encode to fixed point, then cast to the session's tensor type.
            self.tensor = self._encode(data).type(tensor_type)

    def _encode(self, data):
        """Encode ``data`` with this share's FixedPointEncoder."""
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode the share's tensor with this share's FixedPointEncoder."""
        return self._decode()

    def _decode(self):
        """Cast the tensor to LongTensor and decode it from fixed point."""
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(x: "ShareTensor",
                      y: Union[int, float, torch.Tensor, "ShareTensor"],
                      op_str: str) -> "ShareTensor":
        """Check the type of "y" and convert it to a share if necessary.

        A non-ShareTensor ``y`` is wrapped (and encoded) in a new ShareTensor
        that uses ``x.session``. ``op_str`` is currently unused.

        :return: the y value
        :rtype: ShareTensor
        """
        if not isinstance(y, ShareTensor):
            y = ShareTensor(data=y, session=x.session)

        return y

    def apply_function(self, y: Union["ShareTensor", torch.Tensor, int, float],
                       op_str: str) -> "ShareTensor":
        """Apply a given operation.

        ``op_str`` names a function of the ``operator`` module (e.g. "add").

        :return: the result of applying "op_str" on "self" and y
        :rtype: ShareTensor
        """
        op = getattr(operator, op_str)
        if isinstance(y, ShareTensor):
            value = op(self.tensor, y.tensor)
        else:
            value = op(self.tensor, y)

        res = ShareTensor(session=self.session)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        :return: self + y
        :rtype: ShareTensor
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        res = self.apply_function(y_share, "add")
        return res

    def sub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        :return: self - y
        :rtype: ShareTensor
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = self.apply_function(y_share, "sub")
        return res

    def rsub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        :return: y - self
        :rtype: ShareTensor
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        # Operand order is swapped: y_share is the left-hand side.
        res = y_share.apply_function(self, "sub")
        return res

    def mul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        :return: self * y
        :rtype: ShareTensor
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor.
            # In case we used the MPCTensor - the division would have
            # been done in the protocol.
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def matmul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        :return: self @ y
        :rtype: ShareTensor
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor.
            # In case we used the MPCTensor - the division would have
            # been done in the protocol.
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the reversed "matmul" operation between "self" and "y".

        :return: y @ self
        :rtype: ShareTensor
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "div" operation between "self" and "y".

        Only integer and LongTensor divisors are accepted; the result is a
        floor division of the raw (encoded) tensor.

        :return: self // y
        :rtype: ShareTensor
        :raises ValueError: if y is not an int or LongTensor
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        res = ShareTensor(session=self.session)
        res.tensor = self.tensor // y

        return res

    def __getattr__(self, attr_name: str) -> Any:
        """Get the attribute from the ShareTensor.

        If the attribute is not found at the ShareTensor level, look for it in
        "tensor" (e.g. size, shape, ...).

        :return: the attribute value
        :rtype: Anything
        :raises AttributeError: if attr_name is an unset slot of this class, or
            if "tensor" does not have the attribute
        """
        # Slot descriptors only fall through to __getattr__ when the slot is
        # still unset. Forwarding such a name would read self.tensor; if that
        # slot is unset too, it re-enters __getattr__ and recurses forever
        # (e.g. when copy/pickle probe a freshly __new__-ed instance).
        if attr_name in ShareTensor.__slots__:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{attr_name}'")

        # Default to some tensor specific attributes like
        # size, shape, etc.
        return getattr(self.tensor, attr_name)

    def __gt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Check element-wise if "self" is bigger than "y".

        :return: self > y
        :rtype: torch.Tensor (element-wise comparison result)
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        res = self.tensor > y_share.tensor
        return res

    def __lt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Check element-wise if "self" is less than "y".

        :return: self < y
        :rtype: torch.Tensor (element-wise comparison result)
        """
        y_share = ShareTensor.sanity_checks(self, y, "lt")
        res = self.tensor < y_share.tensor
        return res

    def __str__(self) -> str:
        """Return the string representation of ShareTensor."""
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self):
        """Return the same text as __str__."""
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Check if "self" is equal to another ShareTensor.

        Two shares are equal when their tensors have the same shape and equal
        elements (or are both None) and their sessions compare equal.

        :return: if self and other are equal; NotImplemented for non-shares
        :rtype: bool
        """
        if not isinstance(other, ShareTensor):
            return NotImplemented

        if self.tensor is None or other.tensor is None:
            if self.tensor is not other.tensor:
                return False
        elif self.tensor.shape != other.tensor.shape:
            # Without this check, broadcasting could make differently shaped
            # tensors compare as equal.
            return False
        elif not (self.tensor == other.tensor).all():
            return False

        if not (self.session == other.session):
            return False

        return True

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
class ShareTensor(metaclass=SyMPCTensor):
    """Single Share representation.

    Arguments:
        data (Optional[Any]): the share a party holds
        session (Optional[Any]): the session this share belongs to
        encoder_base (int): the base for the encoder
        encoder_precision (int): the precision for the encoder
        ring_size (int): field used for the operations applied on the shares

    Attributes:
        Syft Serializable Attributes

        id (UID): the id to store the session
        tags (Optional[List[str]]): an optional list of strings that are tags used at search
        description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session (Session): keep track from which session this share belongs to
        fp_encoder (FixedPointEncoder): the encoder used to convert a share from/to fixed point
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session",
        "fp_encoder",
    }

    # Used by the SyMPCTensor metaclass
    METHODS_FORWARD: Set[str] = {
        "numel",
        "unsqueeze",
        "t",
        "view",
        "sum",
        "clone",
        "flatten",
        "reshape",
    }
    PROPERTIES_FORWARD: Set[str] = {"T", "shape"}

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        session: Optional[Session] = None,
        encoder_base: int = 2,
        encoder_precision: int = 16,
        ring_size: int = 2**64,
    ) -> None:
        """Initialize ShareTensor.

        If ``session`` is None, a new Session is created for ``ring_size`` and
        the encoder settings are written into its config. Otherwise the encoder
        settings come from the given session's config, and ``encoder_base``,
        ``encoder_precision`` and ``ring_size`` are ignored.

        Args:
            data (Optional[Any]): The share a party holds. Defaults to None
            session (Optional[Any]): The session this share belongs to.
                Defaults to None.
            encoder_base (int): The base for the encoder. Defaults to 2.
            encoder_precision (int): the precision for the encoder. Defaults to 16.
            ring_size (int): field used for the operations applied on the shares
                Defaults to 2**64
        """
        if session is None:
            self.session = Session(ring_size=ring_size, )
            self.session.config.encoder_precision = encoder_precision
            self.session.config.encoder_base = encoder_base
        else:
            self.session = session
            encoder_precision = self.session.config.encoder_precision
            encoder_base = self.session.config.encoder_base

        # TODO: It looks like the same logic as above
        self.fp_encoder = FixedPointEncoder(base=encoder_base, precision=encoder_precision)

        self.tensor: Optional[torch.Tensor] = None
        if data is not None:
            tensor_type = self.session.tensor_type
            # Encode to fixed point, then cast to the session's tensor type.
            self.tensor = self._encode(data).type(tensor_type)

    def _encode(self, data):
        """Encode ``data`` with this share's FixedPointEncoder."""
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPointEncoder.

        Returns:
            torch.Tensor: Decoded value
        """
        return self._decode()

    def _decode(self):
        """Cast the tensor to LongTensor and decode it from fixed point."""
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(x: "ShareTensor",
                      y: Union[int, float, torch.Tensor, "ShareTensor"],
                      op_str: str) -> "ShareTensor":
        """Check the type of "y" and convert it to a share if necessary.

        Args:
            x (ShareTensor): Typically "self".
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Tensor to check.
            op_str (str): String operator (currently unused).

        Returns:
            ShareTensor: the converted y value (a non-share is wrapped and
            encoded using ``x.session``).
        """
        if not isinstance(y, ShareTensor):
            y = ShareTensor(data=y, session=x.session)

        return y

    def apply_function(self, y: Union["ShareTensor", torch.Tensor, int, float],
                       op_str: str) -> "ShareTensor":
        """Apply a given operation.

        Args:
            y (Union["ShareTensor", torch.Tensor, int, float]): tensor to apply the operator.
            op_str (str): Name of a function in the ``operator`` module.

        Returns:
            ShareTensor: Result of the operation.
        """
        op = getattr(operator, op_str)
        if isinstance(y, ShareTensor):
            value = op(self.tensor, y.tensor)
        else:
            value = op(self.tensor, y)

        res = ShareTensor(session=self.session)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self + y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        res = self.apply_function(y_share, "add")
        return res

    def sub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self - y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = self.apply_function(y_share, "sub")
        return res

    def rsub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): y - self

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        # Operand order is swapped: y_share is the left-hand side.
        res = y_share.apply_function(self, "sub")
        return res

    def mul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self * y

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor.
            # In case we used the MPCTensor - the division would have
            # been done in the protocol.
            res.tensor //= self.fp_encoder.scale

        return res

    def xor(self, y: Union[int, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "xor" operation between "self" and "y".

        Unlike the arithmetic operations, a public ``y`` is used as-is (it is
        not fixed-point encoded).

        Args:
            y (Union[int, torch.Tensor, "ShareTensor"]): self xor y

        Returns:
            ShareTensor: Result of the operation.
        """
        res = ShareTensor(session=self.session)
        if isinstance(y, ShareTensor):
            res.tensor = self.tensor ^ y.tensor
        else:
            res.tensor = self.tensor ^ y
        return res

    def matmul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self @ y.

        Returns:
            ShareTensor: Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor.
            # In case we used the MPCTensor - the division would have
            # been done in the protocol.
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y (torch.Tensor): y @ self

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "div" operation between "self" and "y".

        The result is a floor division of the raw (encoded) tensor.

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Denominator.

        Returns:
            ShareTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer or LongTensor.
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        res = ShareTensor(session=self.session)
        res.tensor = self.tensor // y

        return res

    def __gt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Greater than operator (element-wise).

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Element-wise result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        res = self.tensor > y_share.tensor
        return res

    def __lt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Less than operator (element-wise).

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Element-wise result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "lt")
        res = self.tensor < y_share.tensor
        return res

    def __str__(self) -> str:
        """Representation.

        Returns:
            str: Return the string representation of ShareTensor.
        """
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self) -> str:
        """Representation.

        Returns:
            String representation.
        """
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Equal operator.

        Two shares are equal when their tensors have the same shape and equal
        elements (or are both None) and their sessions compare equal.

        Args:
            other (Any): Tensor to compare.

        Returns:
            bool: True if equal False if not (NotImplemented for non-shares).
        """
        if not isinstance(other, ShareTensor):
            return NotImplemented

        if self.tensor is None or other.tensor is None:
            if self.tensor is not other.tensor:
                return False
        elif self.tensor.shape != other.tensor.shape:
            # Without this check, broadcasting could make differently shaped
            # tensors compare as equal.
            return False
        elif not (self.tensor == other.tensor).all():
            return False

        if not (self.session == other.session):
            return False

        return True

    @staticmethod
    def hook_property(property_name: str) -> Any:
        """Hook a framework property (only getter).

        Ex:
         * if we call "shape" we want to call it on the underlying tensor
        and return the result
         * if we call "T" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            property_name (str): property to hook

        Returns:
            A hooked property
        """

        def property_new_share_tensor_getter(_self: "ShareTensor") -> Any:
            tensor = getattr(_self.tensor, property_name)
            res = ShareTensor(session=_self.session)
            res.tensor = tensor
            return res

        def property_getter(_self: "ShareTensor") -> Any:
            prop = getattr(_self.tensor, property_name)
            return prop

        # Names in PROPERTIES_NEW_SHARE_TENSOR get their result re-wrapped
        # in a ShareTensor; the rest are returned raw.
        if property_name in PROPERTIES_NEW_SHARE_TENSOR:
            res = property(property_new_share_tensor_getter, None)
        else:
            res = property(property_getter, None)

        return res

    @staticmethod
    def hook_method(method_name: str) -> Callable[..., Any]:
        """Hook a framework method such that we know how to treat it given that we call it.

        Ex:
         * if we call "numel" we want to call it on the underlying tensor
        and return the result
         * if we call "unsqueeze" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            method_name (str): method to hook

        Returns:
            A hooked method
        """

        def method_new_share_tensor(_self: "ShareTensor", *args: List[Any],
                                    **kwargs: Dict[Any, Any]) -> Any:
            method = getattr(_self.tensor, method_name)
            tensor = method(*args, **kwargs)
            res = ShareTensor(session=_self.session)
            res.tensor = tensor
            return res

        def method(_self: "ShareTensor", *args: List[Any], **kwargs: Dict[Any, Any]) -> Any:
            method = getattr(_self.tensor, method_name)
            res = method(*args, **kwargs)
            return res

        # Names in METHODS_NEW_SHARE_TENSOR get their result re-wrapped
        # in a ShareTensor; the rest are returned raw.
        if method_name in METHODS_NEW_SHARE_TENSOR:
            res = method_new_share_tensor
        else:
            res = method

        return res

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
    __xor__ = xor
class ReplicatedSharedTensor(metaclass=SyMPCTensor): """RSTensor is used when a party holds more than a single share,required by various protocols. Arguments: shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list from which RSTensor is created. Attributes: shares: The shares held by the party """ __slots__ = { # Populated in Syft "id", "tags", "description", "shares", "session_uuid", "config", "fp_encoder", "ring_size", } # Used by the SyMPCTensor metaclass METHODS_FORWARD = { "numel", "t", "unsqueeze", "view", "sum", "clone", "repeat" } PROPERTIES_FORWARD = {"T", "shape"} def __init__( self, shares: Optional[List[Union[float, int, torch.Tensor]]] = None, config: Config = Config(encoder_base=2, encoder_precision=16), session_uuid: Optional[UUID] = None, ring_size: int = 2**64, ): """Initialize ReplicatedSharedTensor. Args: shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list from which RSTensor is created. config (Config): The configuration where we keep the encoder precision and base. session_uuid (Optional[UUID]): Used to keep track of a share that is associated with a remote session ring_size (int): field used for the operations applied on the shares Defaults to 2**64 """ self.session_uuid = session_uuid self.ring_size = ring_size if ring_size in {2, PRIME_NUMBER}: self.config = Config(encoder_base=1, encoder_precision=0) else: self.config = config self.fp_encoder = FixedPointEncoder( base=self.config.encoder_base, precision=self.config.encoder_precision) tensor_type = get_type_from_ring(ring_size) self.shares = [] if shares is not None: self.shares = [ self._encode(share).to(tensor_type) for share in shares ] def _encode(self, data: torch.Tensor) -> torch.Tensor: """Encode via FixedPointEncoder. Args: data (torch.Tensor): Tensor to be encoded Returns: encoded_data (torch.Tensor): Decoded values """ return self.fp_encoder.encode(data) def decode(self) -> List[torch.Tensor]: """Decode via FixedPointEncoder. 
Returns: List[torch.Tensor]: Decoded values """ return self._decode() def _decode(self) -> List[torch.Tensor]: """Decodes shares list of RSTensor via FixedPointEncoder. Returns: List[torch.Tensor]: Decoded values """ shares = [] shares = [ self.fp_encoder.decode(share.type(torch.LongTensor)) for share in self.shares ] return shares def get_shares(self) -> List[torch.Tensor]: """Get shares. Returns: List[torch.Tensor]: List of shares. """ return self.shares def get_ring_size(self) -> str: """Ring size of tensor. Returns: ring_size (str): Returns ring_size of tensor in string. It is typecasted to string as we cannot serialize 2**64 """ return str(self.ring_size) def get_config(self) -> Dict: """Config of tensor. Returns: config (Dict): returns config of the tensor as dict. """ return dataclasses.asdict(self.config) @staticmethod def addmodprime(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Computes addition(x+y) modulo PRIME_NUMBER constant. Args: x (torch.Tensor): input tensor y (torch.tensor): input tensor Returns: value (torch.Tensor): Result of the operation. Raises: ValueError : If either of the tensors datatype is not torch.uint8 """ if x.dtype != torch.uint8 or y.dtype != torch.uint8: raise ValueError( f"Both tensors x:{x.dtype} y:{y.dtype} should be of torch.uint8 dtype" ) return (x + y) % PRIME_NUMBER @staticmethod def submodprime(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Computes subtraction(x-y) modulo PRIME_NUMBER constant. Args: x (torch.Tensor): input tensor y (torch.tensor): input tensor Returns: value (torch.Tensor): Result of the operation. Raises: ValueError : If either of the tensors datatype is not torch.uint8 """ if x.dtype != torch.uint8 or y.dtype != torch.uint8: raise ValueError( f"Both tensors x:{x.dtype} y:{y.dtype} should be of torch.uint8 dtype" ) # Typecasting is done, as underflow returns a positive number,as it is unsigned. 
x = x.to(torch.int8) y = y.to(torch.int8) result = (x - y) % PRIME_NUMBER return result.to(torch.uint8) @staticmethod def mulmodprime(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Computes multiplication(x*y) modulo PRIME_NUMBER constant. Args: x (torch.Tensor): input tensor y (torch.tensor): input tensor Returns: value (torch.Tensor): Result of the operation. Raises: ValueError : If either of the tensors datatype is not torch.uint8 """ if x.dtype != torch.uint8 or y.dtype != torch.uint8: raise ValueError( f"Both tensors x:{x.dtype} y:{y.dtype} should be of torch.uint8 dtype" ) # We typecast as multiplication result in 2n bits ,which causes overflow. x = x.to(torch.int16) y = y.to(torch.int16) result = (x * y) % PRIME_NUMBER return result.to(torch.uint8) @staticmethod def get_op(ring_size: int, op_str: str) -> Callable[..., Any]: """Returns method attribute based on ring_size and op_str. Args: ring_size (int): Ring size op_str (str): Operation string. Returns: op (Callable[...,Any]): The operation method for the op_str. Raises: ValueError : If invalid ring size is given as input. """ op = None if ring_size == 2: op = getattr(operator, BINARY_MAP[op_str]) elif ring_size == PRIME_NUMBER: op = getattr(ReplicatedSharedTensor, op_str + "modprime") elif ring_size in RING_SIZE_TO_TYPE.keys(): op = getattr(operator, op_str) else: raise ValueError(f"Invalid ring size: {ring_size}") return op @staticmethod def sanity_checks( x: "ReplicatedSharedTensor", y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"], ) -> "ReplicatedSharedTensor": """Check the type of "y" and convert it to share if necessary. Args: x (ReplicatedSharedTensor): Typically "self". y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): Tensor to check. Returns: ReplicatedSharedTensor: the converted y value. Raises: ValueError: if both values are shares and they have different uuids ValueError: if both values have different number of shares. 
ValueError: if both RSTensor have different ring_sizes """ if not isinstance(y, ReplicatedSharedTensor): # As prime ring size is unsigned,we convert negative values. y = y % PRIME_NUMBER if x.ring_size == PRIME_NUMBER else y y = ReplicatedSharedTensor( session_uuid=x.session_uuid, shares=[y], ring_size=x.ring_size, config=x.config, ) elif y.session_uuid and x.session_uuid and y.session_uuid != x.session_uuid: raise ValueError( f"Session UUIDs did not match {x.session_uuid} {y.session_uuid}" ) elif len(x.shares) != len(y.shares): raise ValueError( f"Both RSTensors should have equal number of shares {len(x.shares)} {len(y.shares)}" ) elif x.ring_size != y.ring_size: raise ValueError( f"Both RSTensors should have same ring_size {x.ring_size} {y.ring_size}" ) session_uuid = x.session_uuid if session_uuid is not None: session = sympc.session.get_session(str(x.session_uuid)) else: session = Session(config=x.config, ring_size=x.ring_size) session.nr_parties = 1 return y, session def __apply_public_op(self, y: Union[torch.Tensor, float, int], op_str: str) -> "ReplicatedSharedTensor": """Apply an operation on "self" which is a RSTensor and a public value. Args: y (Union[torch.Tensor, float, int]): Tensor to apply the operation. op_str (str): The operation. Returns: ReplicatedSharedTensor: The operation "op_str" applied on "self" and "y" Raises: ValueError: If "op_str" is not supported. 
""" y, session = ReplicatedSharedTensor.sanity_checks(self, y) op = ReplicatedSharedTensor.get_op(self.ring_size, op_str) shares = copy.deepcopy(self.shares) if op_str in {"add", "sub"}: if session.rank != 1: idx = (session.nr_parties - session.rank) % session.nr_parties shares[idx] = op(shares[idx], y.shares[0]) else: raise ValueError(f"{op_str} not supported") result = ReplicatedSharedTensor( ring_size=self.ring_size, session_uuid=self.session_uuid, config=self.config, ) result.shares = shares return result def __apply_private_op(self, y: "ReplicatedSharedTensor", op_str: str) -> "ReplicatedSharedTensor": """Apply an operation on 2 RSTensors (secret shared values). Args: y (RelicatedSharedTensor): Tensor to apply the operation op_str (str): The operation Returns: ReplicatedSharedTensor: The operation "op_str" applied on "self" and "y" Raises: ValueError: If "op_str" not supported. """ y, session = ReplicatedSharedTensor.sanity_checks(self, y) op = ReplicatedSharedTensor.get_op(self.ring_size, op_str) shares = [] if op_str in {"add", "sub"}: for x_share, y_share in zip(self.shares, y.shares): shares.append(op(x_share, y_share)) else: raise ValueError(f"{op_str} not supported") result = ReplicatedSharedTensor( ring_size=self.ring_size, session_uuid=self.session_uuid, config=self.config, ) result.shares = shares return result def __apply_op( self, y: Union["ReplicatedSharedTensor", torch.Tensor, float, int], op_str: str, ) -> "ReplicatedSharedTensor": """Apply a given operation ". This function checks if "y" is private or public value. Args: y (Union[ReplicatedSharedTensor,torch.Tensor, float, int]): tensor to apply the operation. op_str (str): the operation. 
Returns: ReplicatedSharedTensor: the operation "op_str" applied on "self" and "y" """ is_private = isinstance(y, ReplicatedSharedTensor) if is_private: result = self.__apply_private_op(y, op_str) else: result = self.__apply_public_op(y, op_str) return result def add( self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"] ) -> "ReplicatedSharedTensor": """Apply the "add" operation between "self" and "y". Args: y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self + y Returns: ReplicatedSharedTensor: Result of the operation. """ return self.__apply_op(y, "add") def sub( self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"] ) -> "ReplicatedSharedTensor": """Apply the "sub" operation between "self" and "y". Args: y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self - y Returns: ReplicatedSharedTensor: Result of the operation. """ return self.__apply_op(y, "sub") def rsub( self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"] ) -> "ReplicatedSharedTensor": """Apply the "sub" operation between "y" and "self". Args: y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): y -self Returns: ReplicatedSharedTensor: Result of the operation. """ return self.__apply_op(y, "sub") def mul( self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"] ) -> "ReplicatedSharedTensor": """Apply the "mul" operation between "self" and "y". Args: y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self*y Returns: ReplicatedSharedTensor: Result of the operation. Raises: ValueError: Raised when private mul is performed parties!=3. 
""" y_tensor, session = self.sanity_checks(self, y) is_private = isinstance(y, ReplicatedSharedTensor) op_str = "mul" op = ReplicatedSharedTensor.get_op(self.ring_size, op_str) if is_private: if session.nr_parties == 3: from sympc.protocol import Falcon result = [ Falcon.multiplication_protocol(self, y_tensor, op_str) ] else: raise ValueError( "Private mult between ReplicatedSharedTensors is allowed only for 3 parties" ) else: result = [op(share, y_tensor.shares[0]) for share in self.shares] tensor = ReplicatedSharedTensor(ring_size=self.ring_size, session_uuid=self.session_uuid, config=self.config) tensor.shares = result return tensor def matmul( self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"] ) -> "ReplicatedSharedTensor": """Apply the "matmul" operation between "self" and "y". Args: y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self@y Returns: ReplicatedSharedTensor: Result of the operation. Raises: ValueError: Raised when private matmul is performed parties!=3. """ y_tensor, session = self.sanity_checks(self, y) is_private = isinstance(y, ReplicatedSharedTensor) op_str = "matmul" if is_private: if session.nr_parties == 3: from sympc.protocol import Falcon result = [ Falcon.multiplication_protocol(self, y_tensor, op_str) ] else: raise ValueError( "Private matmul between ReplicatedSharedTensors is allowed only for 3 parties" ) else: result = [ operator.matmul(share, y_tensor.shares[0]) for share in self.shares ] tensor = ReplicatedSharedTensor(ring_size=self.ring_size, session_uuid=self.session_uuid, config=self.config) tensor.shares = result return tensor def truediv(self, y: Union[int, torch.Tensor]) -> "ReplicatedSharedTensor": """Apply the "div" operation between "self" and "y". Args: y (Union[int , torch.Tensor]): Denominator. Returns: ReplicatedSharedTensor: Result of the operation. Raises: ValueError: If y is not an integer or LongTensor. 
""" if not isinstance(y, (int, torch.LongTensor)): raise ValueError( "Div works (for the moment) only with integers and LongTensor!" ) res = ReplicatedSharedTensor(session_uuid=self.session_uuid, config=self.config, ring_size=self.ring_size) res.shares = [share // y for share in self.shares] return res def rshift(self, y: int) -> "ReplicatedSharedTensor": """Apply the "rshift" operation to "self". Args: y (int): shift value Returns: ReplicatedSharedTensor: Result of the operation. Raises: ValueError: If y is not an integer. ValueError : If invalid shift value is provided. """ if not isinstance(y, int): raise ValueError("Right Shift works only with integers!") ring_bits = get_nr_bits(self.ring_size) if y < 0 or y > ring_bits - 1: raise ValueError( f"Invalid value for right shift: {y}, must be in range:[0,{ring_bits-1}]" ) res = ReplicatedSharedTensor(session_uuid=self.session_uuid, config=self.config, ring_size=self.ring_size) res.shares = [share >> y for share in self.shares] return res def bit_extraction(self, pos: int = 0) -> "ReplicatedSharedTensor": """Extracts the bit at the specified position. Args: pos (int): position to extract bit. Returns: ReplicatedSharedTensor : extracted bits at specific position. Raises: ValueError: If invalid position is provided. """ ring_bits = get_nr_bits(self.ring_size) if pos < 0 or pos > ring_bits - 1: raise ValueError( f"Invalid position for bit_extraction: {pos}, must be in range:[0,{ring_bits-1}]" ) shares = [] # logical shift bit_mask = torch.ones(self.shares[0].shape, dtype=self.shares[0].dtype) << pos shares = [share & bit_mask for share in self.shares] rst = ReplicatedSharedTensor( shares=shares, session_uuid=self.session_uuid, config=Config(encoder_base=1, encoder_precision=0), ring_size=2, ) return rst def rmatmul(self, y): """Apply the "rmatmul" operation between "y" and "self". 
Args: y: self@y Raises: NotImplementedError: Raised when implementation not present """ raise NotImplementedError def xor( self, y: Union[int, torch.Tensor, "ReplicatedSharedTensor"] ) -> "ReplicatedSharedTensor": """Apply the "xor" operation between "self" and "y". Args: y: public bit Returns: ReplicatedSharedTensor: Result of the operation. Raises: ValueError : If ring size is invalid. """ if self.ring_size == 2: return self + y elif self.ring_size in RING_SIZE_TO_TYPE: return self + y - (self * y * 2) else: raise ValueError( f"The ring_size {self.ring_size} is not supported.") def lt(self, y): """Lower than operator. Args: y: self<y Raises: NotImplementedError: Raised when implementation not present """ raise NotImplementedError def gt(self, y): """Greater than operator. Args: y: self>y Raises: NotImplementedError: Raised when implementation not present """ raise NotImplementedError def eq(self, y: Any) -> bool: """Equal operator. Check if "self" is equal with another object given a set of attributes to compare. Args: y (Any): Object to compare Returns: bool: True if equal False if not. """ if not (torch.cat(self.shares) == torch.cat(y.shares)).all(): return False if self.config != y.config: return False if self.session_uuid and y.session_uuid and self.session_uuid != y.session_uuid: return False if self.ring_size != y.ring_size: return False return True def __getitem__(self, key: int) -> torch.Tensor: """Allows to subset shares. Args: key (int): The share to be retrieved. Returns: share (torch.Tensor): Returned share. """ return self.shares[key] def __setitem__(self, key: int, newvalue: torch.Tensor) -> None: """Allows to set share value to new value. Args: key (int): The share to be retrieved. newvalue (torch.Tensor): New value of share. """ self.shares[key] = newvalue def ne(self, y): """Not Equal operator. 
Args: y: self!=y Raises: NotImplementedError: Raised when implementation not present """ raise NotImplementedError @staticmethod def shares_sum(shares: List[torch.Tensor], ring_size: int) -> torch.Tensor: """Returns sum of tensors based on ring_size. Args: shares (List[torch.Tensor]) : List of tensors. ring_size (int): Ring size of share associated with the tensors. Returns: value (torch.Tensor): sum of the tensors. """ if ring_size == 2: return reduce(lambda x, y: x ^ y, shares) elif ring_size == PRIME_NUMBER: return reduce(ReplicatedSharedTensor.addmodprime, shares) else: return sum(shares) @staticmethod def _request_and_get( share_ptr: "ReplicatedSharedTensor", ) -> "ReplicatedSharedTensor": """Function used to request and get a share - Duet Setup. Args: share_ptr (ReplicatedSharedTensor): input ReplicatedSharedTensor Returns: ReplicatedSharedTensor : The ReplicatedSharedTensor in local. """ if not ispointer(share_ptr): return share_ptr elif not islocal(share_ptr): share_ptr.request(block=True) res = share_ptr.get_copy() return res @staticmethod def __reconstruct_semi_honest( share_ptrs: List["ReplicatedSharedTensor"], get_shares: bool = False, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Reconstruct value from shares. Args: share_ptrs (List[ReplicatedSharedTensor]): List of RSTensor pointers. get_shares (bool): Retrieve only shares. Returns: reconstructed_value (torch.Tensor): Reconstructed value. """ request = ReplicatedSharedTensor._request_and_get request_wrap = parallel_execution(request) args = [[share] for share in share_ptrs[:2]] local_shares = request_wrap(args) shares = [local_shares[0].shares[0]] shares.extend(local_shares[1].shares) if get_shares: return shares ring_size = local_shares[0].ring_size return ReplicatedSharedTensor.shares_sum(shares, ring_size) @staticmethod def __reconstruct_malicious( share_ptrs: List["ReplicatedSharedTensor"], get_shares: bool = False, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Reconstruct value from shares. 
Args: share_ptrs (List[ReplicatedSharedTensor]): List of RSTensor pointers. get_shares (bool): Retrieve only shares. Returns: reconstructed_value (torch.Tensor): Reconstructed value. Raises: ValueError: When parties share values are not equal. """ nparties = len(share_ptrs) # Get shares from all parties request = ReplicatedSharedTensor._request_and_get request_wrap = parallel_execution(request) args = [[share] for share in share_ptrs] local_shares = request_wrap(args) ring_size = local_shares[0].ring_size shares_sum = ReplicatedSharedTensor.shares_sum all_shares = [rst.shares for rst in local_shares] # reconstruct shares from all parties and verify value = None for party_rank in range(nparties): tensor = shares_sum( [all_shares[party_rank][0]] + all_shares[(party_rank + 1) % (nparties)], ring_size, ) if value is None: value = tensor elif (tensor != value).any(): raise ValueError( "Reconstruction values from all parties are not equal.") if get_shares: return all_shares return value @staticmethod def reconstruct( share_ptrs: List["ReplicatedSharedTensor"], get_shares: bool = False, security_type: str = "semi-honest", ) -> Union[torch.Tensor, List[torch.Tensor]]: """Reconstruct value from shares. Args: share_ptrs (List[ReplicatedSharedTensor]): List of RSTensor pointers. security_type (str): Type of security followed by protocol. get_shares (bool): Retrieve only shares. Returns: reconstructed_value (torch.Tensor): Reconstructed value. Raises: ValueError: Invalid security type. ValueError : SharePointers not provided. 
""" if not len(share_ptrs): raise ValueError( "Share pointers must be provided for reconstruction.") if security_type == "malicious": return ReplicatedSharedTensor.__reconstruct_malicious( share_ptrs, get_shares) elif security_type == "semi-honest": return ReplicatedSharedTensor.__reconstruct_semi_honest( share_ptrs, get_shares) raise ValueError("Invalid security Type") @staticmethod def distribute_shares_to_party( shares: List[Union[ShareTensor, torch.Tensor]], party_rank: int, session: Session, ring_size: int, config: Config, ) -> "ReplicatedSharedTensor": """Distributes shares to party. Args: shares (List[Union[ShareTensor,torch.Tensor]]): Shares to be distributed. party_rank (int): Rank of party. session (Session): Current session ring_size (int): Ring size of tensor to distribute config (Config): The configuration(base,precision) of the tensor. Returns: tensor (ReplicatedSharedTensor): Tensor with shares Raises: TypeError: Invalid share class. """ party = session.parties[party_rank] nshares = session.nr_parties - 1 party_shares = [] for share_index in range(party_rank, party_rank + nshares): share = shares[share_index % (nshares + 1)] if isinstance(share, torch.Tensor): party_shares.append(share) elif isinstance(share, ShareTensor): party_shares.append(share.tensor) else: raise TypeError(f"{type(share)} is an invalid share class") tensor = ReplicatedSharedTensor( session_uuid=session.rank_to_uuid[party_rank], config=config, ring_size=ring_size, ) tensor.shares = party_shares return tensor.send(party) @staticmethod def distribute_shares( shares: List[Union[ShareTensor, torch.Tensor]], session: Session, ring_size: Optional[int] = None, config: Optional[Config] = None, ) -> List["ReplicatedSharedTensor"]: """Distribute a list of shares. Args: shares (List[ShareTensor): list of shares to distribute. session (Session): Session. ring_size (int): ring_size the shares belong to. config (Config): The configuration(base,precision) of the tensor. 
Returns: List of ReplicatedSharedTensors. Raises: TypeError: when Datatype of shares is invalid. """ if not isinstance(shares, (list, tuple)): raise TypeError( "Shares to be distributed should be a list of shares") if len(shares) != session.nr_parties: return ValueError( "Number of shares to be distributed should be same as number of parties" ) if ring_size is None: ring_size = session.ring_size if config is None: config = session.config args = [[shares, party_rank, session, ring_size, config] for party_rank in range(session.nr_parties)] return [ ReplicatedSharedTensor.distribute_shares_to_party(*arg) for arg in args ] @staticmethod def hook_property(property_name: str) -> Any: """Hook a framework property (only getter). Ex: * if we call "shape" we want to call it on the underlying tensor and return the result * if we call "T" we want to call it on the underlying tensor but we want to wrap it in the same tensor type Args: property_name (str): property to hook Returns: A hooked property """ def property_new_rs_tensor_getter( _self: "ReplicatedSharedTensor") -> Any: shares = [] for share in _self.shares: tensor = getattr(share, property_name) shares.append(tensor) res = ReplicatedSharedTensor( session_uuid=_self.session_uuid, config=_self.config, ring_size=_self.ring_size, ) res.shares = shares return res def property_getter(_self: "ReplicatedSharedTensor") -> Any: prop = getattr(_self.shares[0], property_name) return prop if property_name in PROPERTIES_NEW_RS_TENSOR: res = property(property_new_rs_tensor_getter, None) else: res = property(property_getter, None) return res @staticmethod def hook_method(method_name: str) -> Callable[..., Any]: """Hook a framework method such that we know how to treat it given that we call it. 
Ex: * if we call "numel" we want to call it on the underlying tensor and return the result * if we call "unsqueeze" we want to call it on the underlying tensor but we want to wrap it in the same tensor type Args: method_name (str): method to hook Returns: A hooked method """ def method_new_rs_tensor(_self: "ReplicatedSharedTensor", *args: List[Any], **kwargs: Dict[Any, Any]) -> Any: shares = [] for share in _self.shares: tensor = getattr(share, method_name)(*args, **kwargs) shares.append(tensor) res = ReplicatedSharedTensor( session_uuid=_self.session_uuid, config=_self.config, ring_size=_self.ring_size, ) res.shares = shares return res def method(_self: "ReplicatedSharedTensor", *args: List[Any], **kwargs: Dict[Any, Any]) -> Any: method = getattr(_self.shares[0], method_name) res = method(*args, **kwargs) return res if method_name in METHODS_NEW_RS_TENSOR: res = method_new_rs_tensor else: res = method return res __add__ = add __radd__ = add __sub__ = sub __rsub__ = rsub __mul__ = mul __rmul__ = mul __matmul__ = matmul __rmatmul__ = rmatmul __truediv__ = truediv __floordiv__ = truediv __xor__ = xor __eq__ = eq __rshift__ = rshift
class ShareTensor:
    """This class represents 1 share that a party holds when doing secret sharing.

    Attributes:
        Syft Serializable Attributes
            id (UID): the id to store the session
            tags (Optional[List[str]): an optional list of strings that are tags used at search
            description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session (Session): keep track from which session this share belongs to
        fp_encoder (FixedPointEncoder): the encoder used to convert a share from/to fixed point
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session",
        "fp_encoder",
    }

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        session: Optional[Session] = None,
        encoder_base: int = 2,
        encoder_precision: int = 16,
        ring_size: int = 2**64,
    ) -> None:
        """Initialize ShareTensor.

        Args:
            data (Optional[Any]): The share a party holds. Defaults to None
            session (Optional[Any]): The session from which this shares belongs to.
                Defaults to None. When given, its config's encoder base/precision
                override ``encoder_base``/``encoder_precision``.
            encoder_base (int): The base for the encoder. Defaults to 2.
            encoder_precision (int): the precision for the encoder. Defaults to 16.
            ring_size (int): field used for the operations applied on the shares
                Defaults to 2**64
        """
        if session is None:
            self.session = Session(ring_size=ring_size)
            self.session.config.encoder_precision = encoder_precision
            self.session.config.encoder_base = encoder_base

        else:
            self.session = session
            encoder_precision = self.session.config.encoder_precision
            encoder_base = self.session.config.encoder_base

        # TODO: It looks like the same logic as above
        self.fp_encoder = FixedPointEncoder(base=encoder_base, precision=encoder_precision)

        self.tensor: Optional[torch.Tensor] = None
        if data is not None:
            tensor_type = self.session.tensor_type
            self.tensor = self._encode(data).type(tensor_type)

    def _encode(self, data):
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPrecisionEncoder.

        Returns:
            torch.Tensor: Decoded value
        """
        return self._decode()

    def _decode(self):
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(
        x: "ShareTensor", y: Union[int, float, torch.Tensor, "ShareTensor"], op_str: str
    ) -> "ShareTensor":
        """Check the type of "y" and convert it to share if necessary.

        Args:
            x (ShareTensor): Typically "self".
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Tensor to check.
            op_str (str): String operator (currently unused).

        Returns:
            ShareTensor: the converted y value.
        """
        if not isinstance(y, ShareTensor):
            # Public values are encoded with x's session (same encoder settings).
            y = ShareTensor(data=y, session=x.session)

        return y

    def apply_function(
        self, y: Union["ShareTensor", torch.Tensor, int, float], op_str: str
    ) -> "ShareTensor":
        """Apply a given operation.

        Args:
            y (Union["ShareTensor", torch.Tensor, int, float]): tensor to apply the operator.
            op_str (str): Name of a function in the ``operator`` module.

        Returns:
            ShareTensor: Result of the operation.
        """
        op = getattr(operator, op_str)

        if isinstance(y, ShareTensor):
            value = op(self.tensor, y.tensor)
        else:
            value = op(self.tensor, y)

        res = ShareTensor(session=self.session)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self + y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        res = self.apply_function(y_share, "add")
        return res

    def sub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self - y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = self.apply_function(y_share, "sub")
        return res

    def rsub(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): y - self

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = y_share.apply_function(self, "sub")
        return res

    def mul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self * y

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def xor(self, y: Union[int, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "xor" operation between "self" and "y".

        Args:
            y (Union[int, torch.Tensor, "ShareTensor"]): self xor y

        Returns:
            ShareTensor: Result of the operation.
        """
        res = ShareTensor(session=self.session)

        if isinstance(y, ShareTensor):
            res.tensor = self.tensor ^ y.tensor
        else:
            res.tensor = self.tensor ^ y

        return res

    def matmul(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self @ y.

        Returns:
            ShareTensor: Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y (torch.Tensor): y @ self

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "div" operation between "self" and "y".

        The division is an integer (floor) division of the raw share.

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Denominator.

        Returns:
            ShareTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer or LongTensor.
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        res = ShareTensor(session=self.session)
        res.tensor = self.tensor // y

        return res

    def __getattr__(self, attr_name: str) -> Any:
        """Access to tensor attributes.

        Python only calls this after normal attribute lookup failed. Unknown
        attributes are forwarded to the underlying tensor.

        Args:
            attr_name (str): Name of the attribute.

        Returns:
            Any: Attribute.

        Raises:
            AttributeError: If the underlying tensor lacks the attribute, or if
                the "tensor" slot itself has not been set yet.
        """
        if attr_name == "tensor":
            # Reaching here means the "tensor" slot is unset (e.g. an instance
            # created without __init__). Reading self.tensor again would
            # re-enter this method until RecursionError.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute 'tensor'"
            )

        return getattr(self.tensor, attr_name)

    def __gt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Greater than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Element-wise result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        res = self.tensor > y_share.tensor
        return res

    def __lt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Lower than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Element-wise result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "lt")
        res = self.tensor < y_share.tensor
        return res

    def __str__(self) -> str:
        """Representation.

        Returns:
            str: Return the string representation of ShareTensor.
        """
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self) -> str:
        """Representation.

        Returns:
            String representation.
        """
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Equal operator.

        Two shares are equal when all tensor elements match and they belong to
        equal sessions. Objects that are not ShareTensors are never equal.

        Args:
            other (Any): Tensor to compare.

        Returns:
            bool: True if equal False if not.
        """
        if not isinstance(other, ShareTensor):
            return False

        if not (self.tensor == other.tensor).all():
            return False

        if not (self.session == other.session):
            return False

        return True

    # Forward to tensor methods

    @property
    def shape(self) -> Any:
        """Shape of the tensor.

        Returns:
            Any: Shape of the tensor.
        """
        return self.tensor.shape

    def numel(self, *args: List[Any], **kwargs: Dict[Any, Any]) -> Any:
        """Total number of elements.

        Args:
            *args: Arguments passed to tensor.numel.
            **kwargs: Keyword arguments passed to tensor.numel.

        Returns:
            Any: Total number of elements of the tensor.
        """
        return self.tensor.numel(*args, **kwargs)

    @property
    def T(self) -> Any:
        """Transpose.

        Returns:
            Any: ShareTensor transposed.
        """
        res = ShareTensor(session=self.session)
        res.tensor = self.tensor.T
        return res

    def unsqueeze(self, *args: List[Any], **kwargs: Dict[Any, Any]) -> Any:
        """Tensor with a dimension of size one inserted at the specified position.

        Args:
            *args: Arguments to tensor.unsqueeze
            **kwargs: Keyword arguments passed to tensor.unsqueeze

        Returns:
            Any: ShareTensor unsqueezed.

        References:
            https://pytorch.org/docs/stable/generated/torch.unsqueeze.html
        """
        tensor = self.tensor.unsqueeze(*args, **kwargs)
        res = ShareTensor(session=self.session)
        res.tensor = tensor
        return res

    def view(self, *args: List[Any], **kwargs: Dict[Any, Any]) -> Any:
        """Tensor with the same data but new dimensions/view.

        Args:
            *args: Arguments to tensor.view.
            **kwargs: Keyword arguments passed to tensor.view.

        Returns:
            Any: ShareTensor with new view.

        References:
            https://pytorch.org/docs/stable/tensors.html#torch.Tensor.view
        """
        tensor = self.tensor.view(*args, **kwargs)
        res = ShareTensor(session=self.session)
        res.tensor = tensor
        return res

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
    # div already performs floor division, so "//" maps to it too.
    __floordiv__ = div
    __xor__ = xor