# Example no. 1
# 0
def test_fp_encoding():
    """Test correct encoding with FixedPointEncoder.

    Tensors, floats and ints must all encode to the integer value multiplied
    by the encoder scale.
    """
    fp_encoder = FixedPointEncoder()
    # Test encoding with tensors.
    tensor = torch.Tensor([1, 2, 3])
    encoded_tensor = fp_encoder.encode(tensor)
    target_tensor = torch.LongTensor([1, 2, 3]) * fp_encoder.scale
    assert (encoded_tensor == target_tensor).all()
    # Test encoding with floats.
    test_float = 42.0
    encoded_float = fp_encoder.encode(test_float)
    target_float = torch.LongTensor([42]) * fp_encoder.scale
    assert (encoded_float == target_float).all()
    # Test encoding with ints.
    test_int = 2
    encoded_int = fp_encoder.encode(test_int)
    target_int = torch.LongTensor([2]) * fp_encoder.scale
    assert (encoded_int == target_int).all()
def test_fixed_point(precision, base) -> None:
    """Check that ReplicatedSharedTensor encodes/decodes like FixedPointEncoder.

    Args:
        precision: encoder precision forwarded to both Config and FixedPointEncoder.
        base: encoder base forwarded to both Config and FixedPointEncoder.
    """
    x = torch.tensor([1.25, 3.301])
    shares = [x, x]
    rst = ReplicatedSharedTensor(shares=shares,
                                 config=Config(encoder_precision=precision,
                                               encoder_base=base))
    fp_encoder = FixedPointEncoder(precision=precision, base=base)
    tensor_type = get_type_from_ring(rst.ring_size)

    # Build fresh lists rather than overwriting ``shares`` in place: that
    # exact list object was handed to the ReplicatedSharedTensor constructor.
    encoded = [fp_encoder.encode(share).to(tensor_type) for share in shares]
    assert (torch.cat(encoded) == torch.cat(rst.shares)).all()

    decoded = [
        fp_encoder.decode(share.type(torch.LongTensor)) for share in encoded
    ]
    assert (torch.cat(decoded) == torch.cat(rst.decode())).all()
# Example no. 3
# 0
def test_private_compare(get_clients, security) -> None:
    """Check Falcon.private_compare on bit-decomposed, prime-ring shares.

    Args:
        get_clients: fixture returning a list of parties when called with a count.
        security: security type forwarded to the Falcon protocol.
    """
    parties = get_clients(3)
    falcon = Falcon(security_type=security)
    session = Session(parties=parties, protocol=falcon)
    SessionManager.setup_mpc(session)
    base = session.config.encoder_base
    precision = session.config.encoder_precision
    fp_encoder = FixedPointEncoder(base=base, precision=precision)

    secret = torch.tensor([[358.85, 79.29], [67.78, 2415.50]])
    r = torch.tensor([[357.05, 90], [145.32, 2400.54]])
    r = fp_encoder.encode(r)
    x = MPCTensor(secret=secret, session=session)
    x_b = ABY3.bit_decomposition_ttp(x, session)  # bit shares
    # Inject every bit share into the PRIME_NUMBER ring.
    x_p = [ABY3.bit_injection(share, session, PRIME_NUMBER) for share in x_b]

    tensor_type = get_type_from_ring(session.ring_size)
    result = Falcon.private_compare(x_p, r.type(tensor_type))
    # The expected pattern equals (secret > r) element-wise — presumably
    # private_compare computes x > r; confirm against the protocol docs.
    expected_res = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool)
    assert (result.reconstruct(decode=False) == expected_res).all()
# Example no. 4
# 0
class ShareTensor(metaclass=SyMPCTensor):
    """Single Share representation.

    Arguments:
        data (Optional[Union[float, int, torch.Tensor]]): the share a party holds
        config (Config): the configuration holding the encoder base and precision
        session_uuid (Optional[UUID]): the session from which this share belongs to
        ring_size (int): ring used for the operations applied on the shares

    Attributes:
        Syft Serializable Attributes

        id (UID): the id to store the session
        tags (Optional[List[str]]): an optional list of strings that are tags used at search
        description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session_uuid (Optional[UUID]): keep track from which session the share belongs to
        config (Config): the configuration holding the encoder base and precision
        fp_encoder (FixedPointEncoder): encoder built from ``config``
        ring_size (int): ring used for the operations applied on the shares
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session_uuid",
        "config",
        "fp_encoder",
        "ring_size",
    }

    # Used by the SyMPCTensor metaclass
    METHODS_FORWARD: Set[str] = {
        "numel",
        "squeeze",
        "unsqueeze",
        "t",
        "view",
        "expand",
        "sum",
        "clone",
        "flatten",
        "reshape",
        "repeat",
        "narrow",
        "dim",
        "transpose",
        "roll",
    }
    PROPERTIES_FORWARD: Set[str] = {"T", "shape"}

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        config: Config = Config(encoder_base=2, encoder_precision=16),
        session_uuid: Optional[UUID] = None,
        ring_size: int = 2**64,
    ) -> None:
        """Initialize ShareTensor.

        Args:
            data (Optional[Any]): The share a party holds. Defaults to None
            config (Config): The configuration where we keep the encoder precision and base.
            session_uuid (Optional[UUID]): Used to keep track of a share that is associated with a
                remote session
            ring_size (int): field used for the operations applied on the shares
                Defaults to 2**64
        """
        self.session_uuid = session_uuid
        self.ring_size = ring_size

        # NOTE(review): the default ``config`` is evaluated once, when the
        # class body runs, so every ShareTensor created without an explicit
        # config shares that same Config object. Do not mutate it in place.
        self.config = config
        self.fp_encoder = FixedPointEncoder(base=config.encoder_base,
                                            precision=config.encoder_precision)

        self.tensor: Optional[torch.Tensor] = None
        if data is not None:
            # Fixed-point encode, then cast to the type returned for this ring.
            tensor_type = get_type_from_ring(ring_size)
            self.tensor = self._encode(data).to(tensor_type)

    def _encode(self, data):
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPrecisionEncoder.

        Returns:
            torch.Tensor: Decoded value
        """
        return self._decode()

    def _decode(self):
        # The stored tensor is cast to LongTensor before decoding.
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(x: "ShareTensor", y: Union[int, float, torch.Tensor,
                                                 "ShareTensor"],
                      op_str: str) -> "ShareTensor":
        """Check the type of "y" and convert it to share if necessary.

        Args:
            x (ShareTensor): Typically "self".
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Tensor to check.
            op_str (str): String operator (currently unused).

        Returns:
            ShareTensor: the converted y value.

        Raises:
            ValueError: if both values are shares and they have different uuids
        """
        if not isinstance(y, ShareTensor):
            # Prefer the ring size/config of the session "x" is attached to.
            if x.session_uuid is not None:
                session = sympc.session.get_session(str(x.session_uuid))
                ring_size = session.ring_size
                config = session.config
            else:
                ring_size = x.ring_size
                config = x.config

            y = ShareTensor(data=y, ring_size=ring_size, config=config)

        elif y.session_uuid and x.session_uuid and y.session_uuid != x.session_uuid:
            raise ValueError(
                f"Session UUIDs did not match {x.session_uuid} {y.session_uuid}"
            )

        return y

    def apply_function(self, y: Union["ShareTensor", torch.Tensor, int, float],
                       op_str: str) -> "ShareTensor":
        """Apply a given operation.

        Args:
            y (Union["ShareTensor", torch.Tensor, int, float]): tensor to apply the operator.
            op_str (str): Name of a function in the ``operator`` module.

        Returns:
            ShareTensor: Result of the operation.
        """
        op = getattr(operator, op_str)

        if isinstance(y, ShareTensor):
            value = op(self.tensor, y.tensor)
            y_session_uuid = y.session_uuid
        else:
            # Plain values (ints, floats, torch tensors) carry no session;
            # reading ``y.session_uuid`` on them would raise AttributeError.
            value = op(self.tensor, y)
            y_session_uuid = None

        session_uuid = self.session_uuid or y_session_uuid
        if session_uuid is not None:
            session = sympc.session.get_session(str(session_uuid))
            ring_size = session.ring_size
            config = session.config
        else:
            # Use the values from "self"
            ring_size = self.ring_size
            config = self.config

        res = ShareTensor(ring_size=ring_size,
                          session_uuid=session_uuid,
                          config=config)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self + y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        res = self.apply_function(y_share, "add")
        return res

    def sub(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self - y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = self.apply_function(y_share, "sub")
        return res

    def rsub(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): y - self

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = y_share.apply_function(self, "sub")
        return res

    def mul(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self * y

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session_uuid is None and y.session_uuid is None:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor //= self.fp_encoder.scale

        return res

    def xor(self, y: Union[int, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "xor" operation between "self" and "y".

        Args:
            y (Union[int, torch.Tensor, "ShareTensor"]): self xor y; plain
                values are used as-is (not fixed-point encoded).

        Returns:
            ShareTensor: Result of the operation.
        """
        res = self.apply_function(y, "xor")
        return res

    def matmul(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self @ y.

        Returns:
            ShareTensor: Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session_uuid is None and y.session_uuid is None:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y (torch.Tensor): y @ self

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "div" operation between "self" and "y".

        The underlying tensor is floor-divided by ``y``.

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Denominator;
                only ``int`` and ``torch.LongTensor`` are accepted for now.

        Returns:
            ShareTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer or LongTensor.
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        # Keep the ring of "self"; without it the result would silently
        # fall back to the constructor default (2**64).
        res = ShareTensor(session_uuid=self.session_uuid,
                          config=self.config,
                          ring_size=self.ring_size)
        res.tensor = self.tensor // y
        return res

    def __gt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Greater than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Element-wise result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        res = self.tensor > y_share.tensor
        return res

    def __lt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Less than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Element-wise result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "lt")
        res = self.tensor < y_share.tensor
        return res

    def __str__(self) -> str:
        """Representation.

        Returns:
            str: Return the string representation of ShareTensor.
        """
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| Session UUID: {self.session_uuid}"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self) -> str:
        """Representation.

        Returns:
            String representation.
        """
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Equal operator.

        Check if "self" is equal with another object given a set of
            attributes to compare.

        Args:
            other (Any): Tensor to compare.

        Returns:
            bool: True if equal False if not; NotImplemented for objects
                that have no ``tensor`` attribute.

        """
        if not hasattr(other, "tensor"):
            # Let Python fall back to its default comparison instead of
            # raising AttributeError (e.g. for ``share == None``).
            return NotImplemented

        if not (self.tensor == other.tensor).all():
            return False

        if not self.config == other.config:
            return False

        if (self.session_uuid and other.session_uuid
                and self.session_uuid != other.session_uuid):
            # If both shares have a session_uuid we consider them not equal
            # else they are
            return False

        return True

    @staticmethod
    def hook_property(property_name: str) -> Any:
        """Hook a framework property (only getter).

        Ex:
         * if we call "shape" we want to call it on the underlying tensor
        and return the result
         * if we call "T" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            property_name (str): property to hook

        Returns:
            A hooked property
        """
        def property_new_share_tensor_getter(_self: "ShareTensor") -> Any:
            tensor = getattr(_self.tensor, property_name)
            # Preserve the ring of the source share in the wrapped result.
            res = ShareTensor(session_uuid=_self.session_uuid,
                              config=_self.config,
                              ring_size=_self.ring_size)
            res.tensor = tensor
            return res

        def property_getter(_self: "ShareTensor") -> Any:
            prop = getattr(_self.tensor, property_name)
            return prop

        if property_name in PROPERTIES_NEW_SHARE_TENSOR:
            res = property(property_new_share_tensor_getter, None)
        else:
            res = property(property_getter, None)

        return res

    @staticmethod
    def hook_method(method_name: str) -> Callable[..., Any]:
        """Hook a framework method such that we know how to treat it given that we call it.

        Ex:
         * if we call "numel" we want to call it on the underlying tensor
        and return the result
         * if we call "unsqueeze" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            method_name (str): method to hook

        Returns:
            A hooked method
        """
        def method_new_share_tensor(_self: "ShareTensor", *args: List[Any],
                                    **kwargs: Dict[Any, Any]) -> Any:
            method = getattr(_self.tensor, method_name)
            tensor = method(*args, **kwargs)
            # Preserve the ring of the source share in the wrapped result.
            res = ShareTensor(session_uuid=_self.session_uuid,
                              config=_self.config,
                              ring_size=_self.ring_size)
            res.tensor = tensor
            return res

        def method(_self: "ShareTensor", *args: List[Any],
                   **kwargs: Dict[Any, Any]) -> Any:
            method = getattr(_self.tensor, method_name)
            res = method(*args, **kwargs)
            return res

        if method_name in METHODS_NEW_SHARE_TENSOR:
            res = method_new_share_tensor
        else:
            res = method

        return res

    @staticmethod
    def reconstruct(
        share_ptrs: List["ShareTensor"],
        get_shares: bool = False,
        security_type: str = "semi-honest",
    ) -> torch.Tensor:
        """Reconstruct original value from shares.

        Each share is requested (when not local) and copied via
        ``parallel_execution``; the underlying tensors are then summed.

        Args:
            share_ptrs (List[ShareTensor]): List of sharetensors.
            get_shares (boolean): retrieve shares or reconstructed value.
            security_type (str): Type of security by protocol (currently unused).

        Returns:
            plaintext/shares (torch.Tensor/List[torch.Tensors]): Plaintext or list of shares.

        """
        def _request_and_get(share_ptr: ShareTensor) -> ShareTensor:
            """Function used to request and get a share - Duet Setup.

            Args:
                share_ptr (ShareTensor): a ShareTensor

            Returns:
                ShareTensor. The ShareTensor in local.

            """
            if not islocal(share_ptr):
                share_ptr.request(block=True)
            res = share_ptr.get_copy()
            return res

        request = _request_and_get
        request_wrap = parallel_execution(request)

        args = [[share] for share in share_ptrs]
        local_shares = request_wrap(args)

        shares = [share.tensor for share in local_shares]

        if get_shares:
            return shares

        plaintext = sum(shares)

        return plaintext

    @staticmethod
    def distribute_shares(shares: List["ShareTensor"], session: Session):
        """Distribute a list of shares.

        Share ``i`` is tagged with ``session.rank_to_uuid[i]`` and sent to
        ``session.parties[i]``.

        Args:
            shares (List[ShareTensor]): list of shares to distribute.
            session (Session): Session for which those shares were generated

        Returns:
            List of ShareTensorPointers.
        """
        rank_to_uuid = session.rank_to_uuid
        parties = session.parties

        share_ptrs = []
        for rank, share in enumerate(shares):
            share.session_uuid = rank_to_uuid[rank]
            party = parties[rank]
            share_ptrs.append(share.send(party))

        return share_ptrs

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
    __xor__ = xor
# Example no. 5
# 0
class ShareTensor:
    """This class represents 1 share that a party holds when doing secret sharing.

    Attributes:
        Syft Serializable Attributes

        id (UID): the id to store the session
        tags (Optional[List[str]]): an optional list of strings that are tags used at search
        description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session (Session): keep track from which session this share belongs to
        fp_encoder (FixedPointEncoder): the encoder used to convert a share from/to fixed point
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session",
        "fp_encoder",
    }

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        session: Optional[Session] = None,
        encoder_base: int = 2,
        encoder_precision: int = 16,
        ring_size: int = 2**64,
    ) -> None:
        """Initialize ShareTensor.

        Args:
            data (Optional[Any]): The share a party holds. Defaults to None
            session (Optional[Any]): The session from which this shares belongs to.
                Defaults to None.
            encoder_base (int): The base for the encoder. Defaults to 2.
                Ignored when ``session`` is given.
            encoder_precision (int): the precision for the encoder. Defaults to 16.
                Ignored when ``session`` is given.
            ring_size (int): field used for the operations applied on the shares
                Defaults to 2**64. Ignored when ``session`` is given.
        """
        if session is None:
            # Standalone share: create a new Session carrying this share's
            # ring size and encoder settings.
            self.session = Session(ring_size=ring_size)
            self.session.config.encoder_precision = encoder_precision
            self.session.config.encoder_base = encoder_base

        else:
            # An explicit session wins: its config overrides the encoder
            # arguments passed to this constructor.
            self.session = session
            encoder_precision = self.session.config.encoder_precision
            encoder_base = self.session.config.encoder_base

        # TODO: It looks like the same logic as above
        self.fp_encoder = FixedPointEncoder(base=encoder_base,
                                            precision=encoder_precision)

        self.tensor: Optional[torch.Tensor] = None

        if data is not None:
            tensor_type = self.session.tensor_type
            self.tensor = self._encode(data).type(tensor_type)

    def _encode(self, data):
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPrecisionEncoder.

        Returns:
            torch.Tensor: Decoded value
        """
        return self._decode()

    def _decode(self):
        # The stored tensor is cast to LongTensor before decoding.
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(x: "ShareTensor", y: Union[int, float, torch.Tensor,
                                                 "ShareTensor"],
                      op_str: str) -> "ShareTensor":
        """Check the type of "y" and convert it to share if necessary.

        Args:
            x (ShareTensor): Typically "self".
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Tensor to check.
            op_str (str): String operator (currently unused).

        Returns:
            ShareTensor: the converted y value.

        """
        if not isinstance(y, ShareTensor):
            y = ShareTensor(data=y, session=x.session)

        return y

    def apply_function(self, y: Union["ShareTensor", torch.Tensor, int, float],
                       op_str: str) -> "ShareTensor":
        """Apply a given operation.

        Args:
            y (Union["ShareTensor", torch.Tensor, int, float]): tensor to apply the operator.
            op_str (str): Name of a function in the ``operator`` module.

        Returns:
            ShareTensor: Result of the operation.
        """
        op = getattr(operator, op_str)

        if isinstance(y, ShareTensor):
            value = op(self.tensor, y.tensor)
        else:
            value = op(self.tensor, y)

        res = ShareTensor(session=self.session)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self + y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        res = self.apply_function(y_share, "add")
        return res

    def sub(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self - y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = self.apply_function(y_share, "sub")
        return res

    def rsub(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): y - self

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = y_share.apply_function(self, "sub")
        return res

    def mul(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self * y

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def xor(self, y: Union[int, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "xor" operation between "self" and "y".

        Args:
            y (Union[int, torch.Tensor, "ShareTensor"]): self xor y; plain
                values are used as-is (not fixed-point encoded).

        Returns:
            ShareTensor: Result of the operation.
        """
        res = ShareTensor(session=self.session)

        if isinstance(y, ShareTensor):
            res.tensor = self.tensor ^ y.tensor
        else:
            res.tensor = self.tensor ^ y

        return res

    def matmul(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self @ y.

        Returns:
            ShareTensor: Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y (torch.Tensor): y @ self

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "div" operation between "self" and "y".

        The underlying tensor is floor-divided by ``y``.

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Denominator;
                only ``int`` and ``torch.LongTensor`` are accepted for now.

        Returns:
            ShareTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer or LongTensor.
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        res = ShareTensor(session=self.session)
        res.tensor = self.tensor // y

        return res

    def __getattr__(self, attr_name: str) -> Any:
        """Access to tensor attributes.

        Python only calls this when normal attribute lookup fails; the
        lookup is then forwarded to the underlying ``tensor``.

        Args:
            attr_name (str): Name of the attribute.

        Returns:
            Any: Attribute.

        Raises:
            AttributeError: if the ``tensor`` slot itself is not set, or if
                the underlying tensor has no such attribute.
        """
        if attr_name == "tensor":
            # The "tensor" slot is unset (e.g. an instance created through
            # ``__new__`` without ``__init__``, as copy/pickle do). Reading
            # ``self.tensor`` here would re-enter __getattr__ forever.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute 'tensor'")

        tensor = self.tensor
        res = getattr(tensor, attr_name)
        return res

    def __gt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Greater than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Element-wise result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        res = self.tensor > y_share.tensor
        return res

    def __lt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> torch.Tensor:
        """Less than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Element-wise result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "lt")
        res = self.tensor < y_share.tensor
        return res

    def __str__(self) -> str:
        """Representation.

        Returns:
            str: Return the string representation of ShareTensor.
        """
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self) -> str:
        """Representation.

        Returns:
            String representation.
        """
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Equal operator.

        Check if "self" is equal with another object given a set of
            attributes to compare.

        Args:
            other (Any): Tensor to compare.

        Returns:
            bool: True if equal False if not; NotImplemented for objects
                that have no ``tensor`` attribute.

        """
        if not hasattr(other, "tensor"):
            # Let Python fall back to its default comparison instead of
            # raising AttributeError (e.g. for ``share == None``).
            return NotImplemented

        if not (self.tensor == other.tensor).all():
            return False

        if not (self.session == other.session):
            return False

        return True

    # Forward to tensor methods

    @property
    def shape(self) -> Any:
        """Shape of the tensor.

        Returns:
            Any: Shape of the tensor.
        """
        return self.tensor.shape

    def numel(self, *args: List[Any], **kwargs: Dict[Any, Any]) -> Any:
        """Total number of elements.

        Args:
            *args: Arguments passed to tensor.numel.
            **kwargs: Keyword arguments passed to tensor.numel.

        Returns:
            Any: Total number of elements of the tensor.

        """
        return self.tensor.numel(*args, **kwargs)

    @property
    def T(self) -> Any:
        """Transpose.

        Returns:
            Any: ShareTensor transposed.

        """
        res = ShareTensor(session=self.session)
        res.tensor = self.tensor.T
        return res

    def unsqueeze(self, *args: List[Any], **kwargs: Dict[Any, Any]) -> Any:
        """Tensor with a dimension of size one inserted at the specified position.

        Args:
            *args: Arguments to tensor.unsqueeze
            **kwargs: Keyword arguments passed to tensor.unsqueeze

        Returns:
            Any: ShareTensor unsqueezed.

        References:
            https://pytorch.org/docs/stable/generated/torch.unsqueeze.html
        """
        tensor = self.tensor.unsqueeze(*args, **kwargs)
        res = ShareTensor(session=self.session)
        res.tensor = tensor
        return res

    def view(self, *args: List[Any], **kwargs: Dict[Any, Any]) -> Any:
        """Tensor with the same data but new dimensions/view.

        Args:
            *args: Arguments to tensor.view.
            **kwargs: Keyword arguments passed to tensor.view.

        Returns:
            Any: ShareTensor with new view.

        References:
            https://pytorch.org/docs/stable/tensors.html#torch.Tensor.view
        """
        tensor = self.tensor.view(*args, **kwargs)
        res = ShareTensor(session=self.session)
        res.tensor = tensor
        return res

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
    __xor__ = xor
class ReplicatedSharedTensor(metaclass=SyMPCTensor):
    """RSTensor is used when a party holds more than a single share, required by various protocols.

    When the shares are distributed (see ``distribute_shares_to_party``) every party
    receives ``nr_parties - 1`` consecutive shares, so each share is replicated on
    more than one party.

    Arguments:
       shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list
           from which RSTensor is created.


    Attributes:
       shares: The shares held by the party.
       session_uuid: UUID of the session the shares belong to (None when local).
       config: Encoder configuration (base and precision).
       fp_encoder: FixedPointEncoder built from ``config``.
       ring_size: Ring in which the arithmetic on the shares is performed.
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "shares",
        "session_uuid",
        "config",
        "fp_encoder",
        "ring_size",
    }

    # Used by the SyMPCTensor metaclass
    METHODS_FORWARD = {
        "numel", "t", "unsqueeze", "view", "sum", "clone", "repeat"
    }
    PROPERTIES_FORWARD = {"T", "shape"}

    def __init__(
        self,
        shares: Optional[List[Union[float, int, torch.Tensor]]] = None,
        config: Optional[Config] = None,
        session_uuid: Optional[UUID] = None,
        ring_size: int = 2**64,
    ):
        """Initialize ReplicatedSharedTensor.

        Args:
            shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list
                from which RSTensor is created. Every element is fixed-point encoded
                and cast to the tensor type associated with ``ring_size``.
            config (Optional[Config]): The configuration where we keep the encoder precision
                and base. Defaults to ``Config(encoder_base=2, encoder_precision=16)``.
                It is ignored for ring sizes 2 and PRIME_NUMBER, which always use
                base 1 / precision 0.
            session_uuid (Optional[UUID]): Used to keep track of a share that is associated with a
                remote session
            ring_size (int): field used for the operations applied on the shares
                Defaults to 2**64

        """
        self.session_uuid = session_uuid
        self.ring_size = ring_size

        if ring_size in {2, PRIME_NUMBER}:
            # Binary and prime rings hold plain integers, not fixed-point values.
            self.config = Config(encoder_base=1, encoder_precision=0)
        elif config is None:
            # Build a fresh default per instance: a default argument object would be
            # created once and shared (and possibly mutated) by every tensor.
            self.config = Config(encoder_base=2, encoder_precision=16)
        else:
            self.config = config

        self.fp_encoder = FixedPointEncoder(
            base=self.config.encoder_base,
            precision=self.config.encoder_precision)

        tensor_type = get_type_from_ring(ring_size)

        self.shares = []

        if shares is not None:
            self.shares = [
                self._encode(share).to(tensor_type) for share in shares
            ]

    def _encode(self, data: torch.Tensor) -> torch.Tensor:
        """Encode via FixedPointEncoder.

        Args:
            data (torch.Tensor): Tensor to be encoded

        Returns:
            encoded_data (torch.Tensor): Encoded values

        """
        return self.fp_encoder.encode(data)

    def decode(self) -> List[torch.Tensor]:
        """Decode via FixedPointEncoder.

        Returns:
            List[torch.Tensor]: Decoded values, one per share.
        """
        return self._decode()

    def _decode(self) -> List[torch.Tensor]:
        """Decodes shares list of RSTensor via FixedPointEncoder.

        Returns:
            List[torch.Tensor]: Decoded values, one per share.
        """
        # Shares are cast to LongTensor before decoding (their stored dtype depends
        # on the ring size).
        return [
            self.fp_encoder.decode(share.type(torch.LongTensor))
            for share in self.shares
        ]

    def get_shares(self) -> List[torch.Tensor]:
        """Get shares.

        Returns:
            List[torch.Tensor]: List of shares.
        """
        return self.shares

    def get_ring_size(self) -> str:
        """Ring size of tensor.

        Returns:
            ring_size (str): Returns ring_size of tensor in string.

        It is typecasted to string as we cannot serialize 2**64
        """
        return str(self.ring_size)

    def get_config(self) -> Dict:
        """Config of tensor.

        Returns:
            config (Dict): returns config of the tensor as dict.
        """
        return dataclasses.asdict(self.config)

    @staticmethod
    def addmodprime(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Computes addition(x+y) modulo PRIME_NUMBER constant.

        Args:
            x (torch.Tensor): input tensor
            y (torch.tensor): input tensor

        Returns:
            value (torch.Tensor): Result of the operation.

        Raises:
            ValueError : If either of the tensors datatype is not torch.uint8
        """
        if x.dtype != torch.uint8 or y.dtype != torch.uint8:
            raise ValueError(
                f"Both tensors x:{x.dtype} y:{y.dtype} should be of torch.uint8 dtype"
            )

        return (x + y) % PRIME_NUMBER

    @staticmethod
    def submodprime(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Computes subtraction(x-y) modulo PRIME_NUMBER constant.

        Args:
            x (torch.Tensor): input tensor
            y (torch.tensor): input tensor

        Returns:
            value (torch.Tensor): Result of the operation.

        Raises:
            ValueError : If either of the tensors datatype is not torch.uint8
        """
        if x.dtype != torch.uint8 or y.dtype != torch.uint8:
            raise ValueError(
                f"Both tensors x:{x.dtype} y:{y.dtype} should be of torch.uint8 dtype"
            )

        # Typecasting is done, as underflow returns a positive number,as it is unsigned.
        # NOTE(review): assumes operands are < 128 so they fit in int8 — holds for
        # values already reduced modulo a small PRIME_NUMBER; confirm its value.
        x = x.to(torch.int8)
        y = y.to(torch.int8)

        result = (x - y) % PRIME_NUMBER

        return result.to(torch.uint8)

    @staticmethod
    def mulmodprime(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Computes multiplication(x*y) modulo PRIME_NUMBER constant.

        Args:
            x (torch.Tensor): input tensor
            y (torch.tensor): input tensor

        Returns:
            value (torch.Tensor): Result of the operation.

        Raises:
            ValueError : If either of the tensors datatype is not torch.uint8
        """
        if x.dtype != torch.uint8 or y.dtype != torch.uint8:
            raise ValueError(
                f"Both tensors x:{x.dtype} y:{y.dtype} should be of torch.uint8 dtype"
            )

        # We typecast as multiplication result in 2n bits ,which causes overflow.
        x = x.to(torch.int16)
        y = y.to(torch.int16)

        result = (x * y) % PRIME_NUMBER

        return result.to(torch.uint8)

    @staticmethod
    def get_op(ring_size: int, op_str: str) -> Callable[..., Any]:
        """Returns method attribute based on ring_size and op_str.

        Ring size 2 maps ``op_str`` through BINARY_MAP to an ``operator`` function,
        PRIME_NUMBER uses the ``<op_str>modprime`` static methods of this class, and
        every other supported ring uses ``operator.<op_str>`` directly.

        Args:
            ring_size (int): Ring size
            op_str (str): Operation string.

        Returns:
            op (Callable[...,Any]): The operation method for the op_str.

        Raises:
            ValueError : If invalid ring size is given as input.
        """
        op = None
        if ring_size == 2:
            op = getattr(operator, BINARY_MAP[op_str])
        elif ring_size == PRIME_NUMBER:
            op = getattr(ReplicatedSharedTensor, op_str + "modprime")
        elif ring_size in RING_SIZE_TO_TYPE.keys():
            op = getattr(operator, op_str)
        else:
            raise ValueError(f"Invalid ring size: {ring_size}")

        return op

    @staticmethod
    def sanity_checks(
        x: "ReplicatedSharedTensor",
        y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"],
    ) -> "tuple[ReplicatedSharedTensor, Session]":
        """Check the type of "y" and convert it to share if necessary.

        Args:
            x (ReplicatedSharedTensor): Typically "self".
            y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): Tensor to check.

        Returns:
            tuple: ``(y, session)`` where ``y`` is converted to a ReplicatedSharedTensor
            (a public value becomes an RSTensor with a single encoded share) and
            ``session`` is the session of ``x``, or a fresh single-party local
            Session when ``x`` has no session UUID.

        Raises:
            ValueError: if both values are shares and they have different uuids
            ValueError: if both values have different number of shares.
            ValueError: if both RSTensor have different ring_sizes
        """
        if not isinstance(y, ReplicatedSharedTensor):
            # As prime ring size is unsigned,we convert negative values.
            y = y % PRIME_NUMBER if x.ring_size == PRIME_NUMBER else y

            y = ReplicatedSharedTensor(
                session_uuid=x.session_uuid,
                shares=[y],
                ring_size=x.ring_size,
                config=x.config,
            )

        elif y.session_uuid and x.session_uuid and y.session_uuid != x.session_uuid:
            raise ValueError(
                f"Session UUIDs did not match {x.session_uuid} {y.session_uuid}"
            )
        elif len(x.shares) != len(y.shares):
            raise ValueError(
                f"Both RSTensors should have equal number of shares {len(x.shares)} {len(y.shares)}"
            )
        elif x.ring_size != y.ring_size:
            raise ValueError(
                f"Both RSTensors should have same ring_size {x.ring_size} {y.ring_size}"
            )

        session_uuid = x.session_uuid

        if session_uuid is not None:
            session = sympc.session.get_session(str(x.session_uuid))
        else:
            session = Session(config=x.config, ring_size=x.ring_size)
            session.nr_parties = 1

        return y, session

    def __apply_public_op(self, y: Union[torch.Tensor, float, int],
                          op_str: str) -> "ReplicatedSharedTensor":
        """Apply an operation on "self" which is a RSTensor and a public value.

        Args:
            y (Union[torch.Tensor, float, int]): Tensor to apply the operation.
            op_str (str): The operation.

        Returns:
            ReplicatedSharedTensor: The operation "op_str" applied on "self" and "y"

        Raises:
            ValueError: If "op_str" is not supported.
        """
        y, session = ReplicatedSharedTensor.sanity_checks(self, y)

        op = ReplicatedSharedTensor.get_op(self.ring_size, op_str)

        shares = copy.deepcopy(self.shares)
        if op_str in {"add", "sub"}:
            # The public value must enter the secret exactly once, so only share 0
            # absorbs it. Party `rank` holds shares rank, rank+1, ... (mod n) and is
            # missing share rank-1, so rank 1 is the only party without share 0; for
            # everyone else share 0 sits at position (n - rank) % n of its list.
            if session.rank != 1:
                idx = (session.nr_parties - session.rank) % session.nr_parties
                shares[idx] = op(shares[idx], y.shares[0])
        else:
            raise ValueError(f"{op_str} not supported")

        result = ReplicatedSharedTensor(
            ring_size=self.ring_size,
            session_uuid=self.session_uuid,
            config=self.config,
        )
        result.shares = shares
        return result

    def __apply_private_op(self, y: "ReplicatedSharedTensor",
                           op_str: str) -> "ReplicatedSharedTensor":
        """Apply an operation on 2 RSTensors (secret shared values).

        Args:
            y (RelicatedSharedTensor): Tensor to apply the operation
            op_str (str): The operation

        Returns:
            ReplicatedSharedTensor: The operation "op_str" applied on "self" and "y"

        Raises:
            ValueError: If "op_str" not supported.
        """
        y, session = ReplicatedSharedTensor.sanity_checks(self, y)

        op = ReplicatedSharedTensor.get_op(self.ring_size, op_str)

        shares = []
        if op_str in {"add", "sub"}:
            # Linear operations are applied share-wise, locally.
            for x_share, y_share in zip(self.shares, y.shares):
                shares.append(op(x_share, y_share))
        else:
            raise ValueError(f"{op_str} not supported")

        result = ReplicatedSharedTensor(
            ring_size=self.ring_size,
            session_uuid=self.session_uuid,
            config=self.config,
        )
        result.shares = shares
        return result

    def __apply_op(
        self,
        y: Union["ReplicatedSharedTensor", torch.Tensor, float, int],
        op_str: str,
    ) -> "ReplicatedSharedTensor":
        """Apply a given operation.

        This function checks if "y" is a private or a public value and dispatches
        to the matching implementation.

        Args:
            y (Union[ReplicatedSharedTensor,torch.Tensor, float, int]): tensor
                to apply the operation.
            op_str (str): the operation.

        Returns:
            ReplicatedSharedTensor: the operation "op_str" applied on "self" and "y"
        """
        is_private = isinstance(y, ReplicatedSharedTensor)

        if is_private:
            result = self.__apply_private_op(y, op_str)
        else:
            result = self.__apply_public_op(y, op_str)

        return result

    def add(
        self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]
    ) -> "ReplicatedSharedTensor":
        """Apply the "add" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self + y

        Returns:
            ReplicatedSharedTensor: Result of the operation.
        """
        return self.__apply_op(y, "add")

    def sub(
        self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]
    ) -> "ReplicatedSharedTensor":
        """Apply the "sub" operation between  "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self - y

        Returns:
            ReplicatedSharedTensor: Result of the operation.
        """
        return self.__apply_op(y, "sub")

    def rsub(
        self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]
    ) -> "ReplicatedSharedTensor":
        """Apply the "sub" operation between "y" and "self".

        Computed as -(self - y): the difference is taken through the regular "sub"
        path and every share is then negated inside the ring (0 - share), which
        negates the shared secret.

        Args:
            y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): y - self

        Returns:
            ReplicatedSharedTensor: Result of the operation.
        """
        result = self.__apply_op(y, "sub")

        # Use the ring-specific subtraction so negation is correct for the binary,
        # prime and 2**k rings alike.
        sub_op = ReplicatedSharedTensor.get_op(self.ring_size, "sub")
        result.shares = [
            sub_op(torch.zeros_like(share), share) for share in result.shares
        ]
        return result

    def mul(
        self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]
    ) -> "ReplicatedSharedTensor":
        """Apply the "mul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self*y

        Returns:
            ReplicatedSharedTensor: Result of the operation.

        Raises:
            ValueError: Raised when private mul is performed parties!=3.

        """
        y_tensor, session = self.sanity_checks(self, y)
        is_private = isinstance(y, ReplicatedSharedTensor)

        op_str = "mul"
        op = ReplicatedSharedTensor.get_op(self.ring_size, op_str)
        if is_private:
            if session.nr_parties == 3:
                from sympc.protocol import Falcon

                result = [
                    Falcon.multiplication_protocol(self, y_tensor, op_str)
                ]
            else:
                raise ValueError(
                    "Private mult between ReplicatedSharedTensors is allowed only for 3 parties"
                )
        else:
            # Multiplying by a public value is local: scale every share.
            result = [op(share, y_tensor.shares[0]) for share in self.shares]

        tensor = ReplicatedSharedTensor(ring_size=self.ring_size,
                                        session_uuid=self.session_uuid,
                                        config=self.config)
        tensor.shares = result

        return tensor

    def matmul(
        self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]
    ) -> "ReplicatedSharedTensor":
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self@y

        Returns:
            ReplicatedSharedTensor: Result of the operation.

        Raises:
            ValueError: Raised when private matmul is performed parties!=3.

        """
        y_tensor, session = self.sanity_checks(self, y)
        is_private = isinstance(y, ReplicatedSharedTensor)

        op_str = "matmul"

        if is_private:
            if session.nr_parties == 3:
                from sympc.protocol import Falcon

                result = [
                    Falcon.multiplication_protocol(self, y_tensor, op_str)
                ]
            else:
                raise ValueError(
                    "Private matmul between ReplicatedSharedTensors is allowed only for 3 parties"
                )
        else:
            result = [
                operator.matmul(share, y_tensor.shares[0])
                for share in self.shares
            ]

        tensor = ReplicatedSharedTensor(ring_size=self.ring_size,
                                        session_uuid=self.session_uuid,
                                        config=self.config)
        tensor.shares = result

        return tensor

    def truediv(self, y: Union[int, torch.Tensor]) -> "ReplicatedSharedTensor":
        """Apply the "div" operation between "self" and "y".

        Each share is floor-divided by ``y`` locally.

        Args:
            y (Union[int , torch.Tensor]): Denominator.

        Returns:
            ReplicatedSharedTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer or LongTensor.
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError(
                "Div works (for the moment) only with integers and LongTensor!"
            )

        res = ReplicatedSharedTensor(session_uuid=self.session_uuid,
                                     config=self.config,
                                     ring_size=self.ring_size)
        res.shares = [share // y for share in self.shares]
        return res

    def rshift(self, y: int) -> "ReplicatedSharedTensor":
        """Apply the "rshift" operation to "self".

        Args:
            y (int): shift value

        Returns:
            ReplicatedSharedTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer.
            ValueError : If invalid shift value is provided.
        """
        if not isinstance(y, int):
            raise ValueError("Right Shift works only with integers!")

        ring_bits = get_nr_bits(self.ring_size)
        if y < 0 or y > ring_bits - 1:
            raise ValueError(
                f"Invalid value for right shift: {y}, must be in range:[0,{ring_bits-1}]"
            )

        res = ReplicatedSharedTensor(session_uuid=self.session_uuid,
                                     config=self.config,
                                     ring_size=self.ring_size)
        res.shares = [share >> y for share in self.shares]
        return res

    def bit_extraction(self, pos: int = 0) -> "ReplicatedSharedTensor":
        """Extracts the bit at the specified position.

        Args:
            pos (int): position to extract bit.

        Returns:
            ReplicatedSharedTensor : extracted bits at specific position, as an
            RSTensor in ring size 2.

        Raises:
            ValueError: If invalid position is provided.
        """
        ring_bits = get_nr_bits(self.ring_size)
        if pos < 0 or pos > ring_bits - 1:
            raise ValueError(
                f"Invalid position for bit_extraction: {pos}, must be in range:[0,{ring_bits-1}]"
            )
        # Mask with only bit `pos` set.
        bit_mask = torch.ones(self.shares[0].shape,
                              dtype=self.shares[0].dtype) << pos
        shares = [share & bit_mask for share in self.shares]
        # The masked values (0 or 2**pos) are cast to the ring-2 tensor type by the
        # constructor; presumably that type is bool, so nonzero becomes 1 — confirm
        # in get_type_from_ring.
        rst = ReplicatedSharedTensor(
            shares=shares,
            session_uuid=self.session_uuid,
            config=Config(encoder_base=1, encoder_precision=0),
            ring_size=2,
        )
        return rst

    def rmatmul(self, y):
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y: self@y

        Raises:
            NotImplementedError: Raised when implementation not present
        """
        raise NotImplementedError

    def xor(
        self, y: Union[int, torch.Tensor, "ReplicatedSharedTensor"]
    ) -> "ReplicatedSharedTensor":
        """Apply the "xor" operation between "self" and "y".

        In ring size 2 XOR is addition; in larger rings it is computed with the
        arithmetic identity ``x + y - 2*x*y`` (valid for bit values).

        Args:
            y: public bit

        Returns:
            ReplicatedSharedTensor: Result of the operation.

        Raises:
            ValueError : If ring size is invalid.
        """
        if self.ring_size == 2:
            return self + y
        elif self.ring_size in RING_SIZE_TO_TYPE:
            return self + y - (self * y * 2)
        else:
            raise ValueError(
                f"The ring_size {self.ring_size} is not supported.")

    def lt(self, y):
        """Lower than operator.

        Args:
            y: self<y

        Raises:
            NotImplementedError: Raised when implementation not present
        """
        raise NotImplementedError

    def gt(self, y):
        """Greater than operator.

        Args:
            y: self>y

        Raises:
            NotImplementedError: Raised when implementation not present
        """
        raise NotImplementedError

    def eq(self, y: Any) -> bool:
        """Equal operator.

        Check if "self" is equal with another object given a set of attributes to compare.

        Args:
            y (Any): Object to compare

        Returns:
            bool: True if equal False if not. Objects that are not
            ReplicatedSharedTensors are never equal.
        """
        if not isinstance(y, ReplicatedSharedTensor):
            return False

        if len(self.shares) != len(y.shares):
            return False

        # Check shapes first: broadcasting could otherwise make shares of
        # different shapes compare equal (or raise).
        if any(x_share.shape != y_share.shape
               for x_share, y_share in zip(self.shares, y.shares)):
            return False

        # Compare share by share (also works for 0-dim shares and empty lists).
        if not all(
                bool((x_share == y_share).all())
                for x_share, y_share in zip(self.shares, y.shares)):
            return False

        if self.config != y.config:
            return False

        if self.session_uuid and y.session_uuid and self.session_uuid != y.session_uuid:
            return False

        if self.ring_size != y.ring_size:
            return False

        return True

    def __getitem__(self, key: int) -> torch.Tensor:
        """Allows to subset shares.

        Args:
            key (int): The share to be retrieved.

        Returns:
            share (torch.Tensor): Returned share.
        """
        return self.shares[key]

    def __setitem__(self, key: int, newvalue: torch.Tensor) -> None:
        """Allows to set share value to new value.

        Args:
            key (int): The share to be retrieved.
            newvalue (torch.Tensor): New value of share.

        """
        self.shares[key] = newvalue

    def ne(self, y):
        """Not Equal operator.

        Args:
            y: self!=y

        Raises:
            NotImplementedError: Raised when implementation not present
        """
        raise NotImplementedError

    @staticmethod
    def shares_sum(shares: List[torch.Tensor], ring_size: int) -> torch.Tensor:
        """Returns sum of tensors based on ring_size.

        Ring size 2 combines with XOR, PRIME_NUMBER with addition modulo the
        prime, any other ring with plain addition.

        Args:
            shares (List[torch.Tensor]) : List of tensors.
            ring_size (int): Ring size of share associated with the tensors.

        Returns:
            value (torch.Tensor): sum of the tensors.
        """
        if ring_size == 2:
            return reduce(lambda x, y: x ^ y, shares)
        elif ring_size == PRIME_NUMBER:
            return reduce(ReplicatedSharedTensor.addmodprime, shares)
        else:
            return sum(shares)

    @staticmethod
    def _request_and_get(
        share_ptr: "ReplicatedSharedTensor", ) -> "ReplicatedSharedTensor":
        """Function used to request and get a share - Duet Setup.

        Args:
            share_ptr (ReplicatedSharedTensor): input ReplicatedSharedTensor

        Returns:
            ReplicatedSharedTensor : The ReplicatedSharedTensor in local.
        """
        if not ispointer(share_ptr):
            return share_ptr
        elif not islocal(share_ptr):
            share_ptr.request(block=True)
        res = share_ptr.get_copy()
        return res

    @staticmethod
    def __reconstruct_semi_honest(
        share_ptrs: List["ReplicatedSharedTensor"],
        get_shares: bool = False,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Reconstruct value from shares.

        Args:
            share_ptrs (List[ReplicatedSharedTensor]): List of RSTensor pointers.
            get_shares (bool): Retrieve only shares.

        Returns:
            reconstructed_value (torch.Tensor): Reconstructed value.
        """
        request = ReplicatedSharedTensor._request_and_get
        request_wrap = parallel_execution(request)
        # Two parties suffice: party 0's first share plus all of party 1's shares
        # cover every share exactly once.
        args = [[share] for share in share_ptrs[:2]]
        local_shares = request_wrap(args)

        shares = [local_shares[0].shares[0]]
        shares.extend(local_shares[1].shares)

        if get_shares:
            return shares

        ring_size = local_shares[0].ring_size

        return ReplicatedSharedTensor.shares_sum(shares, ring_size)

    @staticmethod
    def __reconstruct_malicious(
        share_ptrs: List["ReplicatedSharedTensor"],
        get_shares: bool = False,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Reconstruct value from shares.

        Args:
            share_ptrs (List[ReplicatedSharedTensor]): List of RSTensor pointers.
            get_shares (bool): Retrieve only shares.

        Returns:
            reconstructed_value (torch.Tensor): Reconstructed value.

        Raises:
            ValueError: When parties share values are not equal.
        """
        nparties = len(share_ptrs)

        # Get shares from all parties
        request = ReplicatedSharedTensor._request_and_get
        request_wrap = parallel_execution(request)
        args = [[share] for share in share_ptrs]
        local_shares = request_wrap(args)
        ring_size = local_shares[0].ring_size
        shares_sum = ReplicatedSharedTensor.shares_sum

        all_shares = [rst.shares for rst in local_shares]
        # Reconstruct once per party (its first share + the next party's shares)
        # and require every reconstruction to agree.
        value = None
        for party_rank in range(nparties):
            tensor = shares_sum(
                [all_shares[party_rank][0]] + all_shares[(party_rank + 1) %
                                                         (nparties)],
                ring_size,
            )

            if value is None:
                value = tensor
            elif (tensor != value).any():
                raise ValueError(
                    "Reconstruction values from all parties are not equal.")

        if get_shares:
            return all_shares

        return value

    @staticmethod
    def reconstruct(
        share_ptrs: List["ReplicatedSharedTensor"],
        get_shares: bool = False,
        security_type: str = "semi-honest",
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Reconstruct value from shares.

        Args:
            share_ptrs (List[ReplicatedSharedTensor]): List of RSTensor pointers.
            security_type (str): Type of security followed by protocol.
            get_shares (bool): Retrieve only shares.

        Returns:
            reconstructed_value (torch.Tensor): Reconstructed value.

        Raises:
            ValueError: Invalid security type.
            ValueError : SharePointers not provided.
        """
        if not len(share_ptrs):
            raise ValueError(
                "Share pointers must be provided for reconstruction.")
        if security_type == "malicious":
            return ReplicatedSharedTensor.__reconstruct_malicious(
                share_ptrs, get_shares)

        elif security_type == "semi-honest":
            return ReplicatedSharedTensor.__reconstruct_semi_honest(
                share_ptrs, get_shares)

        raise ValueError("Invalid security Type")

    @staticmethod
    def distribute_shares_to_party(
        shares: List[Union[ShareTensor, torch.Tensor]],
        party_rank: int,
        session: Session,
        ring_size: int,
        config: Config,
    ) -> "ReplicatedSharedTensor":
        """Distributes shares to party.

        Party ``party_rank`` receives the ``nr_parties - 1`` consecutive shares
        starting at index ``party_rank`` (wrapping around).

        Args:
            shares (List[Union[ShareTensor,torch.Tensor]]): Shares to be distributed.
            party_rank (int): Rank of party.
            session (Session): Current session
            ring_size (int): Ring size of tensor to distribute
            config (Config): The configuration(base,precision) of the tensor.

        Returns:
            tensor (ReplicatedSharedTensor): Tensor with shares

        Raises:
            TypeError: Invalid share class.
        """
        party = session.parties[party_rank]
        nshares = session.nr_parties - 1
        party_shares = []

        for share_index in range(party_rank, party_rank + nshares):
            share = shares[share_index % (nshares + 1)]

            if isinstance(share, torch.Tensor):
                party_shares.append(share)

            elif isinstance(share, ShareTensor):
                party_shares.append(share.tensor)

            else:
                raise TypeError(f"{type(share)} is an invalid share class")

        tensor = ReplicatedSharedTensor(
            session_uuid=session.rank_to_uuid[party_rank],
            config=config,
            ring_size=ring_size,
        )
        tensor.shares = party_shares
        return tensor.send(party)

    @staticmethod
    def distribute_shares(
        shares: List[Union[ShareTensor, torch.Tensor]],
        session: Session,
        ring_size: Optional[int] = None,
        config: Optional[Config] = None,
    ) -> List["ReplicatedSharedTensor"]:
        """Distribute a list of shares.

        Args:
            shares (List[ShareTensor): list of shares to distribute.
            session (Session): Session.
            ring_size (int): ring_size the shares belong to. Defaults to the
                session's ring size.
            config (Config): The configuration(base,precision) of the tensor.
                Defaults to the session's config.

        Returns:
            List of ReplicatedSharedTensors.

        Raises:
            TypeError: when Datatype of shares is invalid.
            ValueError: when the number of shares differs from the number of parties.

        """
        if not isinstance(shares, (list, tuple)):
            raise TypeError(
                "Shares to be distributed should be a list of shares")

        if len(shares) != session.nr_parties:
            raise ValueError(
                "Number of shares to be distributed should be same as number of parties"
            )

        if ring_size is None:
            ring_size = session.ring_size
        if config is None:
            config = session.config

        args = [[shares, party_rank, session, ring_size, config]
                for party_rank in range(session.nr_parties)]

        return [
            ReplicatedSharedTensor.distribute_shares_to_party(*arg)
            for arg in args
        ]

    @staticmethod
    def hook_property(property_name: str) -> Any:
        """Hook a framework property (only getter).

        Ex:
         * if we call "shape" we want to call it on the underlying tensor
        and return the result
         * if we call "T" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            property_name (str): property to hook

        Returns:
            A hooked property
        """
        def property_new_rs_tensor_getter(
                _self: "ReplicatedSharedTensor") -> Any:
            # Apply the property to every share and wrap the results.
            shares = []

            for share in _self.shares:
                tensor = getattr(share, property_name)
                shares.append(tensor)

            res = ReplicatedSharedTensor(
                session_uuid=_self.session_uuid,
                config=_self.config,
                ring_size=_self.ring_size,
            )
            res.shares = shares
            return res

        def property_getter(_self: "ReplicatedSharedTensor") -> Any:
            # Read the property from the first share only.
            prop = getattr(_self.shares[0], property_name)
            return prop

        if property_name in PROPERTIES_NEW_RS_TENSOR:
            res = property(property_new_rs_tensor_getter, None)
        else:
            res = property(property_getter, None)

        return res

    @staticmethod
    def hook_method(method_name: str) -> Callable[..., Any]:
        """Hook a framework method such that we know how to treat it given that we call it.

        Ex:
         * if we call "numel" we want to call it on the underlying tensor
        and return the result
         * if we call "unsqueeze" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            method_name (str): method to hook

        Returns:
            A hooked method

        """
        def method_new_rs_tensor(_self: "ReplicatedSharedTensor",
                                 *args: List[Any], **kwargs: Dict[Any,
                                                                  Any]) -> Any:
            # Call the method on every share and wrap the results.
            shares = []
            for share in _self.shares:
                tensor = getattr(share, method_name)(*args, **kwargs)
                shares.append(tensor)

            res = ReplicatedSharedTensor(
                session_uuid=_self.session_uuid,
                config=_self.config,
                ring_size=_self.ring_size,
            )
            res.shares = shares
            return res

        def method(_self: "ReplicatedSharedTensor", *args: List[Any],
                   **kwargs: Dict[Any, Any]) -> Any:
            # Call the method on the first share only and return its result as-is.
            method = getattr(_self.shares[0], method_name)
            res = method(*args, **kwargs)
            return res

        if method_name in METHODS_NEW_RS_TENSOR:
            res = method_new_rs_tensor
        else:
            res = method

        return res

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = truediv
    __floordiv__ = truediv
    __xor__ = xor
    __eq__ = eq
    __rshift__ = rshift
Esempio n. 7
0
class ShareTensor:
    """This class represents 1 share that a party holds when doing secret
    sharing.

    Arguments:
        data (Optional[Any]): the share a party holds
        session (Optional[Any]): the session this share belongs to
        encoder_base (int): the base for the encoder
        encoder_precision (int): the precision for the encoder
        ring_size (int): field used for the operations applied on the shares

    Attributes:
        Syft Serializable Attributes

        id (UID): the id to store the session
        tags (Optional[List[str]]): an optional list of strings that are tags used at search
        description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session (Session): keeps track of the session this share belongs to
        fp_encoder (FixedPointEncoder): the encoder used to convert a share from/to fixed point
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session",
        "fp_encoder",
    }

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        session: Optional[Session] = None,
        encoder_base: int = 2,
        encoder_precision: int = 16,
        ring_size: int = 2**64,
    ) -> None:
        """Initializer for the ShareTensor.

        Without a session, a new Session is created for ``ring_size`` and its
        config is set to ``encoder_base``/``encoder_precision``. With a session,
        the session's encoder settings override the arguments. ``data``, when
        given, is fixed-point encoded and cast to the session's tensor type.
        """
        if session is None:
            self.session = Session(ring_size=ring_size, )
            self.session.config.encoder_precision = encoder_precision
            self.session.config.encoder_base = encoder_base

        else:
            self.session = session
            encoder_precision = self.session.config.encoder_precision
            encoder_base = self.session.config.encoder_base

        # TODO: It looks like the same logic as above
        self.fp_encoder = FixedPointEncoder(base=encoder_base,
                                            precision=encoder_precision)

        self.tensor: Optional[torch.Tensor] = None

        if data is not None:
            tensor_type = self.session.tensor_type
            self.tensor = self._encode(data).type(tensor_type)

    def _encode(self, data):
        """Fixed-point encode ``data`` with this share's encoder."""
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode the share value from fixed point.

        :return: the decoded value of "tensor"
        """
        return self._decode()

    def _decode(self):
        """Decode "tensor" (cast to LongTensor first) with this share's encoder."""
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(x: "ShareTensor", y: Union[int, float, torch.Tensor,
                                                 "ShareTensor"],
                      op_str: str) -> "ShareTensor":
        """Check the type of "y" and convert it to a share if necessary.

        Non-share values are encoded into a new ShareTensor that uses the
        session of "x". "op_str" is currently unused.

        :return: the y value
        :rtype: ShareTensor
        """
        if not isinstance(y, ShareTensor):
            y = ShareTensor(data=y, session=x.session)

        return y

    def apply_function(self, y: Union["ShareTensor", torch.Tensor, int, float],
                       op_str: str) -> "ShareTensor":
        """Apply a given operation.

        "op_str" must name a function of the ``operator`` module.

        :return: the result of applying "op_str" on "self" and y
        :rtype: ShareTensor
        """
        op = getattr(operator, op_str)

        if isinstance(y, ShareTensor):
            value = op(self.tensor, y.tensor)
        else:
            value = op(self.tensor, y)

        res = ShareTensor(session=self.session)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        :return: self + y
        :rtype: ShareTensor
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        res = self.apply_function(y_share, "add")
        return res

    def sub(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        :return: self - y
        :rtype: ShareTensor
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = self.apply_function(y_share, "sub")
        return res

    def rsub(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        :return: y - self
        :rtype: ShareTensor
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = y_share.apply_function(self, "sub")
        return res

    def mul(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        :return: self * y
        :rtype: ShareTensor
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def matmul(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        :return: self @ y
        :rtype: ShareTensor
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the reversed "matmul" operation between "self" and "y".

        :return: y @ self
        :rtype: ShareTensor
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "div" operation between "self" and "y".

        Only integer divisors (``int`` or ``torch.LongTensor``) are supported;
        the division is a floor division on the raw share value.

        :return: self // y
        :rtype: ShareTensor
        :raises ValueError: if "y" is not an int or a LongTensor
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        res = ShareTensor(session=self.session)
        res.tensor = self.tensor // y

        return res

    def __getattr__(self, attr_name: str) -> Any:
        """Get the attribute from the ShareTensor. If the attribute is not
        found at the ShareTensor level, look it up on the underlying
        "tensor".

        :return: the attribute value
        :rtype: Anything
        :raises AttributeError: if the "tensor" slot itself has not been set
            (e.g. an instance created via ``__new__`` and never initialized)
        """
        if attr_name == "tensor":
            # __getattr__ is only reached for "tensor" when its slot is unset;
            # forwarding would read self.tensor again and recurse forever.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute 'tensor'")

        # Default to some tensor specific attributes like
        # size, shape, etc.
        tensor = self.tensor
        return getattr(tensor, attr_name)

    def __gt__(self, y: Union["ShareTensor", torch.Tensor,
                              int]) -> torch.Tensor:
        """Check if "self" is bigger than "y" (element-wise on the raw shares).

        :return: self > y
        :rtype: torch.Tensor
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        res = self.tensor > y_share.tensor
        return res

    def __lt__(self, y: Union["ShareTensor", torch.Tensor,
                              int]) -> torch.Tensor:
        """Check if "self" is less than "y" (element-wise on the raw shares).

        :return: self < y
        :rtype: torch.Tensor
        """

        y_share = ShareTensor.sanity_checks(self, y, "lt")
        res = self.tensor < y_share.tensor
        return res

    def __str__(self) -> str:
        """Return the string representation of ShareTensor."""
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self):
        """Return the same representation as ``__str__``."""
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Check if "self" is equal with another object given a set of
        attributes to compare.

        Two shares are equal when all tensor elements match and they belong
        to equal sessions. Objects that are not ShareTensors are never equal.

        :return: if self and other are equal
        :rtype: bool
        """
        if not isinstance(other, ShareTensor):
            # Previously raised AttributeError on ``other.tensor``.
            return False

        if not (self.tensor == other.tensor).all():
            return False

        if not (self.session == other.session):
            return False

        return True

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
Esempio n. 8
0
class ReplicatedSharedTensor(metaclass=SyMPCTensor):
    """RSTensor is used when a party holds more than a single share, required by various protocols.

    Arguments:
       shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list
           from which RSTensor is created.


    Attributes:
       shares: The shares held by the party
    """

    AUTOGRAD_IS_ON: bool = True

    # Used by the SyMPCTensor metaclass
    METHODS_FORWARD = {"numel", "t", "unsqueeze", "view", "sum", "clone"}
    PROPERTIES_FORWARD = {"T"}

    def __init__(
        self,
        shares: Optional[List[Union[float, int, torch.Tensor]]] = None,
        config: Optional[Config] = None,
        session_uuid: Optional[UUID] = None,
        ring_size: int = 2**64,
    ):
        """Initialize ReplicatedSharedTensor.

        Args:
            shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list
                from which RSTensor is created. Each share is fixed-point encoded
                and converted to the tensor type of ``ring_size``.
            config (Optional[Config]): The configuration where we keep the encoder
                precision and base. Defaults to a new
                ``Config(encoder_base=2, encoder_precision=16)`` per instance.
            session_uuid (Optional[UUID]): Used to keep track of a share that is associated with a
                remote session
            ring_size (int): field used for the operations applied on the shares
                Defaults to 2**64

        """
        if config is None:
            # Built per call: a Config default argument would be evaluated once
            # and the same (mutable) object shared by every such tensor.
            config = Config(encoder_base=2, encoder_precision=16)

        self.session_uuid = session_uuid
        self.ring_size = ring_size

        self.config = config
        self.fp_encoder = FixedPointEncoder(base=config.encoder_base,
                                            precision=config.encoder_precision)

        tensor_type = get_type_from_ring(ring_size)
        self.shares = []
        if shares is not None:
            self.shares = [
                self._encode(share).to(tensor_type) for share in shares
            ]

    def _encode(self, data):
        """Fixed-point encode ``data`` with this tensor's encoder."""
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPointEncoder.

        Returns:
            List[torch.Tensor]: Decoded values
        """
        return self._decode()

    def _decode(self):
        """Decodes shares list of RSTensor via FixedPointEncoder.

        Each share is cast to LongTensor before decoding.

        Returns:
            List[torch.Tensor]: Decoded values
        """
        return [
            self.fp_encoder.decode(share.type(torch.LongTensor))
            for share in self.shares
        ]

    # NOTE: the arithmetic and comparison methods below are placeholders; their
    # bodies are docstrings only, so they currently return None.

    def add(self, y):
        """Apply the "add" operation between "self" and "y".

        Args:
            y: self+y

        """

    def sub(self, y):
        """Apply the "sub" operation between "self" and "y".

        Args:
            y: self-y


        """

    def rsub(self, y):
        """Apply the "sub" operation between "y" and "self".

        Args:
            y: y-self

        """

    def mul(self, y):
        """Apply the "mul" operation between "self" and "y".

        Args:
            y: self*y

        """

    def truediv(self, y):
        """Apply the "div" operation between "self" and "y".

        Args:
            y: self/y

        """

    def matmul(self, y):
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y: self@y

        """

    def rmatmul(self, y):
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y: y@self

        """

    def xor(self, y):
        """Apply the "xor" operation between "self" and "y".

        Args:
            y: self^y

        """

    def lt(self, y):
        """Less than operator.

        Args:
            y: self<y

        """

    def gt(self, y):
        """Greater than operator.

        Args:
            y: self>y

        """

    def eq(self, y):
        """Equal operator.

        Args:
            y: self==y

        """

    def ne(self, y):
        """Not Equal operator.

        Args:
            y: self!=y

        """

    @staticmethod
    def hook_property(property_name: str) -> Any:
        """Hook a framework property (only getter).

        Ex:
         * if we call "shape" we want to call it on the underlying tensor
        and return the result
         * if we call "T" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Wrapped results keep the session uuid, config and ring size of the
        original tensor.

        Args:
            property_name (str): property to hook

        Returns:
            A hooked property
        """
        def property_new_rs_tensor_getter(
                _self: "ReplicatedSharedTensor") -> Any:
            shares = [
                getattr(share, property_name) for share in _self.shares
            ]

            # ring_size must be propagated, otherwise the result silently
            # falls back to the 2**64 default.
            res = ReplicatedSharedTensor(session_uuid=_self.session_uuid,
                                         config=_self.config,
                                         ring_size=_self.ring_size)
            res.shares = shares
            return res

        def property_getter(_self: "ReplicatedSharedTensor") -> Any:
            prop = getattr(_self.shares[0], property_name)
            return prop

        if property_name in PROPERTIES_NEW_RS_TENSOR:
            res = property(property_new_rs_tensor_getter, None)
        else:
            res = property(property_getter, None)

        return res

    @staticmethod
    def hook_method(method_name: str) -> Callable[..., Any]:
        """Hook a framework method such that we know how to treat it given that we call it.

        Ex:
         * if we call "numel" we want to call it on the underlying tensor
        and return the result
         * if we call "unsqueeze" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Wrapped results keep the session uuid, config and ring size of the
        original tensor.

        Args:
            method_name (str): method to hook

        Returns:
            A hooked method

        """
        def method_new_rs_tensor(_self: "ReplicatedSharedTensor",
                                 *args: List[Any], **kwargs: Dict[Any,
                                                                  Any]) -> Any:
            shares = [
                getattr(share, method_name)(*args, **kwargs)
                for share in _self.shares
            ]

            # ring_size must be propagated, otherwise the result silently
            # falls back to the 2**64 default.
            res = ReplicatedSharedTensor(session_uuid=_self.session_uuid,
                                         config=_self.config,
                                         ring_size=_self.ring_size)
            res.shares = shares
            return res

        def method(_self: "ReplicatedSharedTensor", *args: List[Any],
                   **kwargs: Dict[Any, Any]) -> Any:
            method = getattr(_self.shares[0], method_name)
            res = method(*args, **kwargs)
            return res

        if method_name in METHODS_NEW_RS_TENSOR:
            res = method_new_rs_tensor
        else:
            res = method

        return res

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = truediv
    __xor__ = xor
Esempio n. 9
0
class ShareTensor(metaclass=SyMPCTensor):
    """Single Share representation.

    Arguments:
        data (Optional[Any]): the share a party holds
        session (Optional[Any]): the session this share belongs to
        encoder_base (int): the base for the encoder
        encoder_precision (int): the precision for the encoder
        ring_size (int): field used for the operations applied on the shares

    Attributes:
        Syft Serializable Attributes

        id (UID): the id to store the session
        tags (Optional[List[str]]): an optional list of strings that are tags used at search
        description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session (Session): keeps track of the session this share belongs to
        fp_encoder (FixedPointEncoder): the encoder used to convert a share from/to fixed point
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session",
        "fp_encoder",
    }

    # Used by the SyMPCTensor metaclass
    METHODS_FORWARD: Set[str] = {
        "numel",
        "unsqueeze",
        "t",
        "view",
        "sum",
        "clone",
        "flatten",
        "reshape",
    }
    PROPERTIES_FORWARD: Set[str] = {"T", "shape"}

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        session: Optional[Session] = None,
        encoder_base: int = 2,
        encoder_precision: int = 16,
        ring_size: int = 2**64,
    ) -> None:
        """Initialize ShareTensor.

        Args:
            data (Optional[Any]): The share a party holds. Defaults to None.
                When given, it is fixed-point encoded and cast to the session's
                tensor type.
            session (Optional[Any]): The session this share belongs to.
                Defaults to None, in which case a new Session is created for
                ``ring_size``. When given, its encoder settings override
                ``encoder_base`` and ``encoder_precision``.
            encoder_base (int): The base for the encoder. Defaults to 2.
            encoder_precision (int): the precision for the encoder. Defaults to 16.
            ring_size (int): field used for the operations applied on the shares
                Defaults to 2**64
        """
        if session is None:
            self.session = Session(ring_size=ring_size, )
            self.session.config.encoder_precision = encoder_precision
            self.session.config.encoder_base = encoder_base

        else:
            self.session = session
            encoder_precision = self.session.config.encoder_precision
            encoder_base = self.session.config.encoder_base

        # TODO: It looks like the same logic as above
        self.fp_encoder = FixedPointEncoder(base=encoder_base,
                                            precision=encoder_precision)

        self.tensor: Optional[torch.Tensor] = None
        if data is not None:
            tensor_type = self.session.tensor_type
            self.tensor = self._encode(data).type(tensor_type)

    def _encode(self, data):
        """Fixed-point encode ``data`` with this share's encoder."""
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPointEncoder.

        Returns:
            torch.Tensor: Decoded value
        """
        return self._decode()

    def _decode(self):
        """Decode "tensor" (cast to LongTensor first) with this share's encoder."""
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(x: "ShareTensor", y: Union[int, float, torch.Tensor,
                                                 "ShareTensor"],
                      op_str: str) -> "ShareTensor":
        """Check the type of "y" and convert it to share if necessary.

        Args:
            x (ShareTensor): Typically "self"; its session is used for the
                conversion.
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Tensor to check.
            op_str (str): String operator (currently unused).

        Returns:
            ShareTensor: the converted y value.

        """
        if not isinstance(y, ShareTensor):
            y = ShareTensor(data=y, session=x.session)

        return y

    def apply_function(self, y: Union["ShareTensor", torch.Tensor, int, float],
                       op_str: str) -> "ShareTensor":
        """Apply a given operation.

        Args:
            y (Union["ShareTensor", torch.Tensor, int, float]): tensor to apply the operator.
            op_str (str): Name of a function in the ``operator`` module.

        Returns:
            ShareTensor: Result of the operation.
        """
        op = getattr(operator, op_str)

        if isinstance(y, ShareTensor):
            value = op(self.tensor, y.tensor)
        else:
            value = op(self.tensor, y)

        res = ShareTensor(session=self.session)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self + y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        res = self.apply_function(y_share, "add")
        return res

    def sub(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self - y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = self.apply_function(y_share, "sub")
        return res

    def rsub(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): y - self

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = y_share.apply_function(self, "sub")
        return res

    def mul(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self * y

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def xor(self, y: Union[int, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "xor" operation between "self" and "y".

        Unlike the arithmetic operations, "y" is used as-is (it is not
        fixed-point encoded).

        Args:
            y (Union[int, torch.Tensor, "ShareTensor"]): self xor y

        Returns:
            ShareTensor: Result of the operation.
        """
        res = ShareTensor(session=self.session)

        if isinstance(y, ShareTensor):
            res.tensor = self.tensor ^ y.tensor
        else:
            res.tensor = self.tensor ^ y

        return res

    def matmul(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self @ y.

        Returns:
            ShareTensor: Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y (torch.Tensor): y @ self

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "div" operation between "self" and "y".

        The division is a floor division on the raw share value.

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Denominator.

        Returns:
            ShareTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer or LongTensor.
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        res = ShareTensor(session=self.session)
        res.tensor = self.tensor // y

        return res

    def __gt__(self, y: Union["ShareTensor", torch.Tensor,
                              int]) -> torch.Tensor:
        """Greater than operator (element-wise on the raw shares).

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        res = self.tensor > y_share.tensor
        return res

    def __lt__(self, y: Union["ShareTensor", torch.Tensor,
                              int]) -> torch.Tensor:
        """Less than operator (element-wise on the raw shares).

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            torch.Tensor: Result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "lt")
        res = self.tensor < y_share.tensor
        return res

    def __str__(self) -> str:
        """Representation.

        Returns:
            str: Return the string representation of ShareTensor.
        """
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self) -> str:
        """Representation.

        Returns:
            String representation.
        """
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Equal operator.

        Two shares are equal when all tensor elements match and they belong to
        equal sessions. Objects that are not ShareTensors are never equal.

        Args:
            other (Any): Tensor to compare.

        Returns:
            bool: True if equal False if not.

        """
        if not isinstance(other, ShareTensor):
            # Previously raised AttributeError on ``other.tensor``.
            return False

        if not (self.tensor == other.tensor).all():
            return False

        if not (self.session == other.session):
            return False

        return True

    @staticmethod
    def hook_property(property_name: str) -> Any:
        """Hook a framework property (only getter).

        Ex:
         * if we call "shape" we want to call it on the underlying tensor
        and return the result
         * if we call "T" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            property_name (str): property to hook

        Returns:
            A hooked property
        """
        def property_new_share_tensor_getter(_self: "ShareTensor") -> Any:
            tensor = getattr(_self.tensor, property_name)
            res = ShareTensor(session=_self.session)
            res.tensor = tensor
            return res

        def property_getter(_self: "ShareTensor") -> Any:
            prop = getattr(_self.tensor, property_name)
            return prop

        if property_name in PROPERTIES_NEW_SHARE_TENSOR:
            res = property(property_new_share_tensor_getter, None)
        else:
            res = property(property_getter, None)

        return res

    @staticmethod
    def hook_method(method_name: str) -> Callable[..., Any]:
        """Hook a framework method such that we know how to treat it given that we call it.

        Ex:
         * if we call "numel" we want to call it on the underlying tensor
        and return the result
         * if we call "unsqueeze" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            method_name (str): method to hook

        Returns:
            A hooked method
        """
        def method_new_share_tensor(_self: "ShareTensor", *args: List[Any],
                                    **kwargs: Dict[Any, Any]) -> Any:
            method = getattr(_self.tensor, method_name)
            tensor = method(*args, **kwargs)
            res = ShareTensor(session=_self.session)
            res.tensor = tensor
            return res

        def method(_self: "ShareTensor", *args: List[Any],
                   **kwargs: Dict[Any, Any]) -> Any:
            method = getattr(_self.tensor, method_name)
            res = method(*args, **kwargs)
            return res

        if method_name in METHODS_NEW_SHARE_TENSOR:
            res = method_new_share_tensor
        else:
            res = method

        return res

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
    __xor__ = xor