Example No. 1
    def reconstruct(
        self, decode: bool = True, get_shares: bool = False
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Reconstruct the secret.

        Request and get the shares from all the parties and reconstruct the
        secret. Depending on the value of "decode", the secret is decoded
        (or not) using the FixedPointEncoder specific to the session.

        Args:
            decode (bool): True to decode using the FixedPointEncoder. Defaults to True.
            get_shares (bool): True to return only the shares. Defaults to False.

        Returns:
            Union[torch.Tensor, List[torch.Tensor]]: The reconstructed secret,
                or the list of shares if "get_shares" is True.
        """
        result = self.session.protocol.share_class.reconstruct(
            self.share_ptrs,
            get_shares=get_shares,
            security_type=self.session.protocol.security_type,
        )

        if get_shares:
            return result

        if decode:
            fp_encoder = FixedPointEncoder(
                base=self.session.config.encoder_base,
                precision=self.session.config.encoder_precision,
            )

            result = fp_encoder.decode(result)

        return result
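A minimal usage sketch for reconstruct(), assuming the sympc session API that appears in the later examples (Session, SessionManager, MPCTensor); get_clients is a hypothetical fixture returning the party clients, as in the tests below:

import torch

from sympc.session import Session, SessionManager
from sympc.tensor import MPCTensor

parties = get_clients(3)  # hypothetical helper, as in the test examples below
session = Session(parties=parties)
SessionManager.setup_mpc(session)

x_mpc = MPCTensor(secret=torch.tensor([1.25, -2.5]), session=session)

plaintext = x_mpc.reconstruct()                  # decoded via FixedPointEncoder
raw_shares = x_mpc.reconstruct(get_shares=True)  # undecoded list of shares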
Example No. 2
    def reconstruct(
            self,
            decode: bool = True,
            get_shares: bool = False
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Request and get the shares from all the parties and reconstruct the
        secret. Depending on the value of "decode", the secret would be decoded
        or not using the FixedPrecision Encoder specific for the session.

        Args:
            decode (bool): True if decode using FixedPointEncoder. Defaults to True
            get_shares (boot): True if get shares. Defaults to False.

        Returns:
            torch.Tensor. The secret reconstructed.
        """
        def _request_and_get(share_ptr: ShareTensor) -> ShareTensor:
            """Function used to request and get a share - Duet Setup.

            Args:
                share_ptr (ShareTensor): a ShareTensor

            Returns:
                ShareTensor: The local ShareTensor.
            """

            if not islocal(share_ptr):
                share_ptr.request(name="reconstruct", block=True)
            res = share_ptr.get_copy()
            return res

        request = _request_and_get

        request_wrap = parallel_execution(request)

        args = [[share] for share in self.share_ptrs]
        local_shares = request_wrap(args)

        shares = [share.tensor for share in local_shares]

        if get_shares:
            return shares

        plaintext = sum(shares)

        if decode:
            fp_encoder = FixedPointEncoder(
                base=self.session.config.encoder_base,
                precision=self.session.config.encoder_precision,
            )

            plaintext = fp_encoder.decode(plaintext)

        return plaintext
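A small sketch of the parallel_execution pattern used above, assuming the wrapper takes one argument list per call and returns the results in call order; the fetch function is an illustrative stand-in for _request_and_get:

from sympc.utils import parallel_execution

def fetch(value: int) -> int:
    # stand-in for _request_and_get: one call per share pointer
    return value * 2

fetch_wrap = parallel_execution(fetch)
args = [[1], [2], [3]]   # one argument list per call, mirroring [[share] for share in self.share_ptrs]
print(fetch_wrap(args))  # [2, 4, 6]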
Example No. 3
def test_fp_decoding():
    """Test correct decoding with FixedPointEncoder."""
    fp_encoder = FixedPointEncoder()
    # Test decoding with tensors.
    # Case A: raise a ValueError with floating point tensors.
    tensor = torch.Tensor([1.00, 2.00, 3.00])
    with pytest.raises(ValueError):
        fp_encoder.decode(tensor)
    # Test decoding with ints.
    # Case A (precision = 0):
    fp_encoder = FixedPointEncoder(precision=0)
    test_int = 2
    decoded_int = fp_encoder.decode(test_int)
    target_int = torch.LongTensor([2])
    assert (decoded_int == target_int).all()
    # Case B (precision != 0):
    fp_encoder = FixedPointEncoder()
    test_int = 2 * fp_encoder.base**fp_encoder.precision
    decoded_int = fp_encoder.decode(test_int)
    target_int = torch.LongTensor([2])
    assert (decoded_int == target_int).all()
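A round-trip sketch for FixedPointEncoder, assuming encode multiplies by scale = base ** precision and decode divides by it, consistent with the test above (test_int = 2 * fp_encoder.base ** fp_encoder.precision):

import torch

from sympc.encoder import FixedPointEncoder

fp_encoder = FixedPointEncoder(base=2, precision=16)
assert fp_encoder.scale == 2**16  # 65536

encoded = fp_encoder.encode(torch.tensor([1.5]))             # 1.5 * 65536 = 98304
decoded = fp_encoder.decode(encoded.type(torch.LongTensor))  # back to 1.5
assert torch.allclose(decoded, torch.tensor([1.5]))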
Example No. 4
    def lt(self, other: "MPCTensor") -> "MPCTensor":
        """Lower than operator.

        Args:
            other (MPCTensor): MPCTensor to compare.

        Returns:
            MPCTensor: Result of the comparison.
        """
        protocol = self.session.get_protocol()
        other = self.__check_or_convert(other, self.session)
        fp_encoder = FixedPointEncoder(
            base=self.session.config.encoder_base,
            precision=self.session.config.encoder_precision,
        )
        one = fp_encoder.decode(1)
        return protocol.le(self + one, other)
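A note on why lt above computes `one = fp_encoder.decode(1)`: decode(1) is the smallest representable fixed-point step, 1 / base ** precision, and for values on that grid a < b is equivalent to a + step <= b, which lets lt reuse protocol.le. A plain-Python check of that equivalence:

base, precision = 2, 16
step = 1 / base**precision  # fp_encoder.decode(1) == 1/65536

a = 1.25
b = a + step                # adjacent points on the fixed-point grid
assert a < b and a + step <= b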
Example No. 5
def test_fixed_point(precision, base) -> None:
    x = torch.tensor([1.25, 3.301])
    shares = [x, x]
    rst = ReplicatedSharedTensor(shares=shares,
                                 config=Config(encoder_precision=precision,
                                               encoder_base=base))
    fp_encoder = FixedPointEncoder(precision=precision, base=base)
    tensor_type = get_type_from_ring(rst.ring_size)
    for i in range(len(shares)):
        shares[i] = fp_encoder.encode(shares[i]).to(tensor_type)

    assert (torch.cat(shares) == torch.cat(rst.shares)).all()

    for i in range(len(shares)):
        shares[i] = fp_encoder.decode(shares[i].type(torch.LongTensor))

    assert (torch.cat(shares) == torch.cat(rst.decode())).all()
Example No. 6
def test_truncation_algorithm1(get_clients, base, precision) -> None:
    parties = get_clients(3)
    falcon = Falcon("semi-honest")
    config = Config(encoder_base=base, encoder_precision=precision)
    session = Session(parties=parties, protocol=falcon, config=config)
    SessionManager.setup_mpc(session)

    x = torch.tensor([[1.24, 4.51, 6.87], [7.87, 1301, 541]])

    x_mpc = MPCTensor(secret=x, session=session)

    result = ABY3.truncate(x_mpc, session, session.ring_size, session.config)

    fp_encoder = FixedPointEncoder(base=session.config.encoder_base,
                                   precision=session.config.encoder_precision)
    expected_res = x_mpc.reconstruct(decode=False) // fp_encoder.scale
    expected_res = fp_encoder.decode(expected_res)

    assert np.allclose(result.reconstruct(), expected_res, atol=1e-3)
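The expected_res computation above divides the raw (undecoded) encoding by scale; the usual role of this truncation is to rescale products, since multiplying two encoded values carries scale**2. A plain-integer sketch of that arithmetic, assuming scale = base ** precision:

base, precision = 2, 16
scale = base**precision        # 65536

x = 1.25
x_enc = int(x * scale)         # 81920, encoded once
prod_enc = x_enc * x_enc       # product carries scale**2
truncated = prod_enc // scale  # back to a single-scale encoding
print(truncated / scale)       # 1.5625 == 1.25 * 1.25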
Example No. 7
    def reconstruct(self, decode: bool = True) -> torch.Tensor:
        """Request and get the shares from all the parties and reconstruct the secret.
        Depending on the value of "decode", the secret would be decoded or not using
        the FixedPrecision Encoder specific for the session

        :return: the secret reconstructed
        :rtype: tensor
        """
        def _request_and_get(share_ptr: ShareTensor) -> ShareTensor:
            """Function used to request and get a share - Duet Setup
            :return: the ShareTensor (local)
            :rtype: ShareTensor
            """

            if not islocal(share_ptr):
                share_ptr.request(name="reconstruct", block=True)
            res = share_ptr.get_copy()
            return res

        request = _request_and_get

        request_wrap = parallel_execution(request)

        args = [[share] for share in self.share_ptrs]
        local_shares = request_wrap(args)

        plaintext = sum(share.tensor for share in local_shares)

        if decode:
            fp_encoder = FixedPointEncoder(
                base=self.session.config.encoder_base,
                precision=self.session.config.encoder_precision,
            )

            plaintext = fp_encoder.decode(plaintext)

        return plaintext
Example No. 8
class ShareTensor(metaclass=SyMPCTensor):
    """Single Share representation.

    Arguments:
        data (Optional[Any]): the share a party holds
        session (Optional[Any]): the session to which this share belongs
        encoder_base (int): the base for the encoder
        encoder_precision (int): the precision for the encoder
        ring_size (int): field used for the operations applied on the shares

    Attributes:
        Syft Serializable Attributes

        id (UID): the id to store the session
        tags (Optional[List[str]]): an optional list of strings that are tags used at search
        description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session_uuid (Optional[UUID]): keeps track of the session the share belongs to
        encoder_precision (int): precision for the encoder
        encoder_base (int): base for the encoder
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session_uuid",
        "config",
        "fp_encoder",
        "ring_size",
    }

    # Used by the SyMPCTensor metaclass
    METHODS_FORWARD: Set[str] = {
        "numel",
        "squeeze",
        "unsqueeze",
        "t",
        "view",
        "expand",
        "sum",
        "clone",
        "flatten",
        "reshape",
        "repeat",
        "narrow",
        "dim",
        "transpose",
        "roll",
    }
    PROPERTIES_FORWARD: Set[str] = {"T", "shape"}

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        config: Config = Config(encoder_base=2, encoder_precision=16),
        session_uuid: Optional[UUID] = None,
        ring_size: int = 2**64,
    ) -> None:
        """Initialize ShareTensor.

        Args:
            data (Optional[Any]): The share a party holds. Defaults to None
            config (Config): The configuration where we keep the encoder precision and base.
            session_uuid (Optional[UUID]): Used to keep track of a share that is associated with a
                remote session
            ring_size (int): field used for the operations applied on the shares
                Defaults to 2**64
        """
        self.session_uuid = session_uuid
        self.ring_size = ring_size

        self.config = config
        self.fp_encoder = FixedPointEncoder(base=config.encoder_base,
                                            precision=config.encoder_precision)

        self.tensor: Optional[torch.Tensor] = None
        if data is not None:
            tensor_type = get_type_from_ring(ring_size)
            self.tensor = self._encode(data).to(tensor_type)

    def _encode(self, data):
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPrecisionEncoder.

        Returns:
            torch.Tensor: Decoded value
        """
        return self._decode()

    def _decode(self):
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(x: "ShareTensor", y: Union[int, float, torch.Tensor,
                                                 "ShareTensor"],
                      op_str: str) -> "ShareTensor":
        """Check the type of "y" and covert it to share if necessary.

        Args:
            x (ShareTensor): Typically "self".
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Tensor to check.
            op_str (str): String operator.

        Returns:
            ShareTensor: the converted y value.

        Raises:
            ValueError: if both values are shares and they have different uuids
        """
        if not isinstance(y, ShareTensor):
            if x.session_uuid is not None:
                session = sympc.session.get_session(str(x.session_uuid))
                ring_size = session.ring_size
                config = session.config
            else:
                ring_size = x.ring_size
                config = x.config

            y = ShareTensor(data=y, ring_size=ring_size, config=config)

        elif y.session_uuid and x.session_uuid and y.session_uuid != x.session_uuid:
            raise ValueError(
                f"Session UUIDs did not match {x.session_uuid} {y.session_uuid}"
            )

        return y

    def apply_function(self, y: Union["ShareTensor", torch.Tensor, int, float],
                       op_str: str) -> "ShareTensor":
        """Apply a given operation.

        Args:
            y (Union["ShareTensor", torch.Tensor, int, float]): tensor to apply the operator.
            op_str (str): Operator.

        Returns:
            ShareTensor: Result of the operation.
        """
        op = getattr(operator, op_str)

        if isinstance(y, ShareTensor):
            value = op(self.tensor, y.tensor)
        else:
            value = op(self.tensor, y)

        session_uuid = self.session_uuid or y.session_uuid
        if session_uuid is not None:
            session = sympc.session.get_session(str(session_uuid))
            ring_size = session.ring_size
            config = session.config
        else:
            # Use the values from "self"
            ring_size = self.ring_size
            config = self.config

        res = ShareTensor(ring_size=ring_size,
                          session_uuid=session_uuid,
                          config=config)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self + y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        res = self.apply_function(y_share, "add")
        return res

    def sub(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self - y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = self.apply_function(y_share, "sub")
        return res

    def rsub(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): y - self

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = y_share.apply_function(self, "sub")
        return res

    def mul(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self * y

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session_uuid is None and y.session_uuid is None:
            # We are using a simple share without using the MPCTensor.
            # Had we used the MPCTensor, the division would have been
            # done in the protocol.
            res.tensor //= self.fp_encoder.scale

        return res

    def xor(self, y: Union[int, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "xor" operation between "self" and "y".

        Args:
            y (Union[int, torch.Tensor, "ShareTensor"]): self xor y

        Returns:
            ShareTensor: Result of the operation.
        """
        res = self.apply_function(y, "xor")
        return res

    def matmul(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self @ y.

        Returns:
            ShareTensor: Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session_uuid is None and y.session_uuid is None:
            # We are using a simple share without using the MPCTensor.
            # Had we used the MPCTensor, the division would have been
            # done in the protocol.
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y (torch.Tensor): y @ self

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "div" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Denominator.

        Returns:
            ShareTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer or LongTensor.
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        res = ShareTensor(session_uuid=self.session_uuid, config=self.config)
        # res = self.apply_function(y, "floordiv")
        res.tensor = self.tensor // y
        return res

    def __gt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> bool:
        """Greater than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            bool: Result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        res = self.tensor > y_share.tensor
        return res

    def __lt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> bool:
        """Lower than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            bool: Result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "lt")
        res = self.tensor < y_share.tensor
        return res

    def __str__(self) -> str:
        """Representation.

        Returns:
            str: Return the string representation of ShareTensor.
        """
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| Session UUID: {self.session_uuid}"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self) -> str:
        """Representation.

        Returns:
            String representation.
        """
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Equal operator.

        Check if "self" is equal with another object given a set of
            attributes to compare.

        Args:
            other (Any): Tensor to compare.

        Returns:
            bool: True if equal, False if not.

        """
        if not (self.tensor == other.tensor).all():
            return False

        if not self.config == other.config:
            return False

        if (self.session_uuid and other.session_uuid
                and self.session_uuid != other.session_uuid):
            # If both shares have a session_uuid and the uuids differ,
            # we consider the shares not equal
            return False

        return True

    @staticmethod
    def hook_property(property_name: str) -> Any:
        """Hook a framework property (only getter).

        Ex:
         * if we call "shape" we want to call it on the underlying tensor
        and return the result
         * if we call "T" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            property_name (str): property to hook

        Returns:
            A hooked property
        """
        def property_new_share_tensor_getter(_self: "ShareTensor") -> Any:
            tensor = getattr(_self.tensor, property_name)
            res = ShareTensor(session_uuid=_self.session_uuid,
                              config=_self.config)
            res.tensor = tensor
            return res

        def property_getter(_self: "ShareTensor") -> Any:
            prop = getattr(_self.tensor, property_name)
            return prop

        if property_name in PROPERTIES_NEW_SHARE_TENSOR:
            res = property(property_new_share_tensor_getter, None)
        else:
            res = property(property_getter, None)

        return res

    @staticmethod
    def hook_method(method_name: str) -> Callable[..., Any]:
        """Hook a framework method such that we know how to treat it given that we call it.

        Ex:
         * if we call "numel" we want to call it on the underlying tensor
        and return the result
         * if we call "unsqueeze" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            method_name (str): method to hook

        Returns:
            A hooked method
        """
        def method_new_share_tensor(_self: "ShareTensor", *args: List[Any],
                                    **kwargs: Dict[Any, Any]) -> Any:
            method = getattr(_self.tensor, method_name)
            tensor = method(*args, **kwargs)
            res = ShareTensor(session_uuid=_self.session_uuid,
                              config=_self.config)
            res.tensor = tensor
            return res

        def method(_self: "ShareTensor", *args: List[Any],
                   **kwargs: Dict[Any, Any]) -> Any:
            method = getattr(_self.tensor, method_name)
            res = method(*args, **kwargs)
            return res

        if method_name in METHODS_NEW_SHARE_TENSOR:
            res = method_new_share_tensor
        else:
            res = method

        return res

    @staticmethod
    def reconstruct(
        share_ptrs: List["ShareTensor"],
        get_shares=False,
        security_type: str = "semi-honest",
    ) -> torch.Tensor:
        """Reconstruct original value from shares.

        Args:
            share_ptrs (List[ShareTensor]): List of sharetensors.
            get_shares (bool): True to return the shares instead of the reconstructed value.
            security_type (str): Type of security by protocol.

        Returns:
            Union[torch.Tensor, List[torch.Tensor]]: Plaintext or list of shares.

        """
        def _request_and_get(share_ptr: ShareTensor) -> ShareTensor:
            """Function used to request and get a share - Duet Setup.

            Args:
                share_ptr (ShareTensor): a ShareTensor

            Returns:
                ShareTensor: The local ShareTensor.

            """
            if not islocal(share_ptr):
                share_ptr.request(block=True)
            res = share_ptr.get_copy()
            return res

        request = _request_and_get
        request_wrap = parallel_execution(request)

        args = [[share] for share in share_ptrs]
        local_shares = request_wrap(args)

        shares = [share.tensor for share in local_shares]

        if get_shares:
            return shares

        plaintext = sum(shares)

        return plaintext

    @staticmethod
    def distribute_shares(shares: List["ShareTensor"], session: Session):
        """Distribute a list of shares.

        Args:
            shares (List[ShareTensor]): list of shares to distribute.
            session (Session): Session for which those shares were generated.

        Returns:
            List of ShareTensorPointers.
        """
        rank_to_uuid = session.rank_to_uuid
        parties = session.parties

        share_ptrs = []
        for rank, share in enumerate(shares):
            share.session_uuid = rank_to_uuid[rank]
            party = parties[rank]
            share_ptrs.append(share.send(party))

        return share_ptrs

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
    __xor__ = xor
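A local (session-free) ShareTensor sketch, assuming the defaults shown in __init__ above (base=2, precision=16) and that ShareTensor is importable from sympc.tensor; with session_uuid left as None on both operands, mul divides by fp_encoder.scale itself, so the decoded product is correct:

import torch

from sympc.tensor import ShareTensor

a = ShareTensor(data=torch.tensor([1.5]))
b = ShareTensor(data=torch.tensor([2.0]))

c = a * b          # apply_function("mul"), then res.tensor //= fp_encoder.scale
print(c.decode())  # tensor([3.])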
Example No. 9
class ShareTensor:
    """This class represents 1 share that a party holds when doing secret sharing.

    Attributes:
        Syft Serializable Attributes

        id (UID): the id to store the session
        tags (Optional[List[str]]): an optional list of strings that are tags used at search
        description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session (Session): the session to which this share belongs
        fp_encoder (FixedPointEncoder): the encoder used to convert a share from/to fixed point
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session",
        "fp_encoder",
    }

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        session: Optional[Session] = None,
        encoder_base: int = 2,
        encoder_precision: int = 16,
        ring_size: int = 2**64,
    ) -> None:
        """Initialize ShareTensor.

        Args:
            data (Optional[Any]): The share a party holds. Defaults to None
            session (Optional[Any]): The session to which this share belongs.
                Defaults to None.
            encoder_base (int): The base for the encoder. Defaults to 2.
            encoder_precision (int): the precision for the encoder. Defaults to 16.
            ring_size (int): field used for the operations applied on the shares
                Defaults to 2**64
        """
        if session is None:
            self.session = Session(ring_size=ring_size)
            self.session.config.encoder_precision = encoder_precision
            self.session.config.encoder_base = encoder_base

        else:
            self.session = session
            encoder_precision = self.session.config.encoder_precision
            encoder_base = self.session.config.encoder_base

        # TODO: It looks like the same logic as above
        self.fp_encoder = FixedPointEncoder(base=encoder_base,
                                            precision=encoder_precision)

        self.tensor: Optional[torch.Tensor] = None

        if data is not None:
            tensor_type = self.session.tensor_type
            self.tensor = self._encode(data).type(tensor_type)

    def _encode(self, data):
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPrecisionEncoder.

        Returns:
            torch.Tensor: Decoded value
        """
        return self._decode()

    def _decode(self):
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(x: "ShareTensor", y: Union[int, float, torch.Tensor,
                                                 "ShareTensor"],
                      op_str: str) -> "ShareTensor":
        """Check the type of "y" and covert it to share if necessary.

        Args:
            x (ShareTensor): Typically "self".
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Tensor to check.
            op_str (str): String operator.

        Returns:
            ShareTensor: the converted y value.

        """
        if not isinstance(y, ShareTensor):
            y = ShareTensor(data=y, session=x.session)

        return y

    def apply_function(self, y: Union["ShareTensor", torch.Tensor, int, float],
                       op_str: str) -> "ShareTensor":
        """Apply a given operation.

        Args:
            y (Union["ShareTensor", torch.Tensor, int, float]): tensor to apply the operator.
            op_str (str): Operator.

        Returns:
            ShareTensor: Result of the operation.
        """
        op = getattr(operator, op_str)

        if isinstance(y, ShareTensor):
            value = op(self.tensor, y.tensor)
        else:
            value = op(self.tensor, y)

        res = ShareTensor(session=self.session)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self + y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        res = self.apply_function(y_share, "add")
        return res

    def sub(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self - y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = self.apply_function(y_share, "sub")
        return res

    def rsub(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): y - self

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = y_share.apply_function(self, "sub")
        return res

    def mul(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self * y

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor.
            # Had we used the MPCTensor, the division would have been
            # done in the protocol.
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def xor(self, y: Union[int, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "xor" operation between "self" and "y".

        Args:
            y (Union[int, torch.Tensor, "ShareTensor"]): self xor y

        Returns:
            ShareTensor: Result of the operation.
        """
        res = ShareTensor(session=self.session)

        if isinstance(y, ShareTensor):
            res.tensor = self.tensor ^ y.tensor
        else:
            res.tensor = self.tensor ^ y

        return res

    def matmul(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self @ y.

        Returns:
            ShareTensor: Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor.
            # Had we used the MPCTensor, the division would have been
            # done in the protocol.
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y (torch.Tensor): y @ self

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "div" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Denominator.

        Returns:
            ShareTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer or LongTensor.
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        res = ShareTensor(session=self.session)
        res.tensor = self.tensor // y

        return res

    def __getattr__(self, attr_name: str) -> Any:
        """Access to tensor attributes.

        Args:
            attr_name (str): Name of the attribute.

        Returns:
            Any: Attribute.
        """
        tensor = self.tensor
        res = getattr(tensor, attr_name)
        return res

    def __gt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> bool:
        """Greater than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            bool: Result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        res = self.tensor > y_share.tensor
        return res

    def __lt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> bool:
        """Lower than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            bool: Result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "lt")
        res = self.tensor < y_share.tensor
        return res

    def __str__(self) -> str:
        """Representation.

        Returns:
            str: Return the string representation of ShareTensor.
        """
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self) -> str:
        """Representation.

        Returns:
            String representation.
        """
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Equal operator.

        Check if "self" is equal with another object given a set of
            attributes to compare.

        Args:
            other (Any): Tensor to compare.

        Returns:
            bool: True if equal, False if not.

        """
        if not (self.tensor == other.tensor).all():
            return False

        if not (self.session == other.session):
            return False

        return True

    # Forward to tensor methods

    @property
    def shape(self) -> Any:
        """Shape of the tensor.

        Returns:
            Any: Shape of the tensor.
        """
        return self.tensor.shape

    def numel(self, *args: List[Any], **kwargs: Dict[Any, Any]) -> Any:
        """Total number of elements.

        Args:
            *args: Arguments passed to tensor.numel.
            **kwargs: Keyword arguments passed to tensor.numel.

        Returns:
            Any: Total number of elements of the tensor.

        """
        return self.tensor.numel(*args, **kwargs)

    @property
    def T(self) -> Any:
        """Transpose.

        Returns:
            Any: ShareTensor transposed.

        """
        res = ShareTensor(session=self.session)
        res.tensor = self.tensor.T
        return res

    def unsqueeze(self, *args: List[Any], **kwargs: Dict[Any, Any]) -> Any:
        """Tensor with a dimension of size one inserted at the specified position.

        Args:
            *args: Arguments to tensor.unsqueeze
            **kwargs: Keyword arguments passed to tensor.unsqueeze

        Returns:
            Any: ShareTensor unsqueezed.

        References:
            https://pytorch.org/docs/stable/generated/torch.unsqueeze.html
        """
        tensor = self.tensor.unsqueeze(*args, **kwargs)
        res = ShareTensor(session=self.session)
        res.tensor = tensor
        return res

    def view(self, *args: List[Any], **kwargs: Dict[Any, Any]) -> Any:
        """Tensor with the same data but new dimensions/view.

        Args:
            *args: Arguments to tensor.view.
            **kwargs: Keyword arguments passed to tensor.view.

        Returns:
            Any: ShareTensor with new view.

        References:
            https://pytorch.org/docs/stable/tensors.html#torch.Tensor.view
        """
        tensor = self.tensor.view(*args, **kwargs)
        res = ShareTensor(session=self.session)
        res.tensor = tensor
        return res

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
    __xor__ = xor
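Example No. 10 below implements xor over an arithmetic ring with the bit identity x ^ y == x + y - 2*x*y; a quick plain-Python check of that identity:

for x in (0, 1):
    for y in (0, 1):
        assert (x ^ y) == x + y - 2 * x * y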
Example No. 10
class ReplicatedSharedTensor(metaclass=SyMPCTensor):
    """RSTensor is used when a party holds more than a single share,required by various protocols.

    Arguments:
       shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list
           from which RSTensor is created.


    Attributes:
       shares: The shares held by the party
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "shares",
        "session_uuid",
        "config",
        "fp_encoder",
        "ring_size",
    }

    # Used by the SyMPCTensor metaclass
    METHODS_FORWARD = {
        "numel", "t", "unsqueeze", "view", "sum", "clone", "repeat"
    }
    PROPERTIES_FORWARD = {"T", "shape"}

    def __init__(
        self,
        shares: Optional[List[Union[float, int, torch.Tensor]]] = None,
        config: Config = Config(encoder_base=2, encoder_precision=16),
        session_uuid: Optional[UUID] = None,
        ring_size: int = 2**64,
    ):
        """Initialize ReplicatedSharedTensor.

        Args:
            shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list
                from which RSTensor is created.
            config (Config): The configuration where we keep the encoder precision and base.
            session_uuid (Optional[UUID]): Used to keep track of a share that is associated with a
                remote session
            ring_size (int): field used for the operations applied on the shares
                Defaults to 2**64

        """
        self.session_uuid = session_uuid
        self.ring_size = ring_size

        if ring_size in {2, PRIME_NUMBER}:
            self.config = Config(encoder_base=1, encoder_precision=0)
        else:
            self.config = config

        self.fp_encoder = FixedPointEncoder(
            base=self.config.encoder_base,
            precision=self.config.encoder_precision)

        tensor_type = get_type_from_ring(ring_size)

        self.shares = []

        if shares is not None:
            self.shares = [
                self._encode(share).to(tensor_type) for share in shares
            ]

    def _encode(self, data: torch.Tensor) -> torch.Tensor:
        """Encode via FixedPointEncoder.

        Args:
            data (torch.Tensor): Tensor to be encoded

        Returns:
            encoded_data (torch.Tensor): Encoded values

        """
        return self.fp_encoder.encode(data)

    def decode(self) -> List[torch.Tensor]:
        """Decode via FixedPointEncoder.

        Returns:
            List[torch.Tensor]: Decoded values
        """
        return self._decode()

    def _decode(self) -> List[torch.Tensor]:
        """Decodes shares list of RSTensor via FixedPointEncoder.

        Returns:
            List[torch.Tensor]: Decoded values
        """
        shares = []

        shares = [
            self.fp_encoder.decode(share.type(torch.LongTensor))
            for share in self.shares
        ]
        return shares

    def get_shares(self) -> List[torch.Tensor]:
        """Get shares.

        Returns:
            List[torch.Tensor]: List of shares.
        """
        return self.shares

    def get_ring_size(self) -> str:
        """Ring size of tensor.

        Returns:
            ring_size (str): The ring_size of the tensor, as a string.

        It is cast to a string because 2**64 cannot be serialized.
        """
        return str(self.ring_size)

    def get_config(self) -> Dict:
        """Config of tensor.

        Returns:
            config (Dict): returns config of the tensor as dict.
        """
        return dataclasses.asdict(self.config)

    @staticmethod
    def addmodprime(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Computes addition(x+y) modulo PRIME_NUMBER constant.

        Args:
            x (torch.Tensor): input tensor
            y (torch.tensor): input tensor

        Returns:
            value (torch.Tensor): Result of the operation.

        Raises:
            ValueError : If either of the tensors datatype is not torch.uint8
        """
        if x.dtype != torch.uint8 or y.dtype != torch.uint8:
            raise ValueError(
                f"Both tensors x:{x.dtype} y:{y.dtype} should be of torch.uint8 dtype"
            )

        return (x + y) % PRIME_NUMBER

    @staticmethod
    def submodprime(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Computes subtraction(x-y) modulo PRIME_NUMBER constant.

        Args:
            x (torch.Tensor): input tensor
            y (torch.tensor): input tensor

        Returns:
            value (torch.Tensor): Result of the operation.

        Raises:
            ValueError : If either of the tensors datatype is not torch.uint8
        """
        if x.dtype != torch.uint8 or y.dtype != torch.uint8:
            raise ValueError(
                f"Both tensors x:{x.dtype} y:{y.dtype} should be of torch.uint8 dtype"
            )

        # Typecasting is done because an underflow would return a positive
        # number, as uint8 is unsigned.
        x = x.to(torch.int8)
        y = y.to(torch.int8)

        result = (x - y) % PRIME_NUMBER

        return result.to(torch.uint8)

    @staticmethod
    def mulmodprime(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Computes multiplication(x*y) modulo PRIME_NUMBER constant.

        Args:
            x (torch.Tensor): input tensor
            y (torch.tensor): input tensor

        Returns:
            value (torch.Tensor): Result of the operation.

        Raises:
            ValueError : If either of the tensors datatype is not torch.uint8
        """
        if x.dtype != torch.uint8 or y.dtype != torch.uint8:
            raise ValueError(
                f"Both tensors x:{x.dtype} y:{y.dtype} should be of torch.uint8 dtype"
            )

        # We typecast because the multiplication result needs 2n bits,
        # which would cause an overflow.
        x = x.to(torch.int16)
        y = y.to(torch.int16)

        result = (x * y) % PRIME_NUMBER

        return result.to(torch.uint8)

    @staticmethod
    def get_op(ring_size: int, op_str: str) -> Callable[..., Any]:
        """Returns method attribute based on ring_size and op_str.

        Args:
            ring_size (int): Ring size
            op_str (str): Operation string.

        Returns:
            op (Callable[...,Any]): The operation method for the op_str.

        Raises:
            ValueError : If invalid ring size is given as input.
        """
        op = None
        if ring_size == 2:
            op = getattr(operator, BINARY_MAP[op_str])
        elif ring_size == PRIME_NUMBER:
            op = getattr(ReplicatedSharedTensor, op_str + "modprime")
        elif ring_size in RING_SIZE_TO_TYPE.keys():
            op = getattr(operator, op_str)
        else:
            raise ValueError(f"Invalid ring size: {ring_size}")

        return op

    @staticmethod
    def sanity_checks(
        x: "ReplicatedSharedTensor",
        y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"],
    ) -> "ReplicatedSharedTensor":
        """Check the type of "y" and convert it to share if necessary.

        Args:
            x (ReplicatedSharedTensor): Typically "self".
            y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): Tensor to check.

        Returns:
            Tuple[ReplicatedSharedTensor, Session]: the converted y value and the session.

        Raises:
            ValueError: if both values are shares and they have different uuids
            ValueError: if both values have different number of shares.
            ValueError: if both RSTensor have different ring_sizes
        """
        if not isinstance(y, ReplicatedSharedTensor):
            # As the prime ring is unsigned, we convert negative values.
            y = y % PRIME_NUMBER if x.ring_size == PRIME_NUMBER else y

            y = ReplicatedSharedTensor(
                session_uuid=x.session_uuid,
                shares=[y],
                ring_size=x.ring_size,
                config=x.config,
            )

        elif y.session_uuid and x.session_uuid and y.session_uuid != x.session_uuid:
            raise ValueError(
                f"Session UUIDs did not match {x.session_uuid} {y.session_uuid}"
            )
        elif len(x.shares) != len(y.shares):
            raise ValueError(
                f"Both RSTensors should have equal number of shares {len(x.shares)} {len(y.shares)}"
            )
        elif x.ring_size != y.ring_size:
            raise ValueError(
                f"Both RSTensors should have same ring_size {x.ring_size} {y.ring_size}"
            )

        session_uuid = x.session_uuid

        if session_uuid is not None:
            session = sympc.session.get_session(str(x.session_uuid))
        else:
            session = Session(config=x.config, ring_size=x.ring_size)
            session.nr_parties = 1

        return y, session

    def __apply_public_op(self, y: Union[torch.Tensor, float, int],
                          op_str: str) -> "ReplicatedSharedTensor":
        """Apply an operation on "self" which is a RSTensor and a public value.

        Args:
            y (Union[torch.Tensor, float, int]): Tensor to apply the operation.
            op_str (str): The operation.

        Returns:
            ReplicatedSharedTensor: The operation "op_str" applied on "self" and "y"

        Raises:
            ValueError: If "op_str" is not supported.
        """
        y, session = ReplicatedSharedTensor.sanity_checks(self, y)

        op = ReplicatedSharedTensor.get_op(self.ring_size, op_str)

        shares = copy.deepcopy(self.shares)
        if op_str in {"add", "sub"}:
            if session.rank != 1:
                idx = (session.nr_parties - session.rank) % session.nr_parties
                shares[idx] = op(shares[idx], y.shares[0])
        else:
            raise ValueError(f"{op_str} not supported")

        result = ReplicatedSharedTensor(
            ring_size=self.ring_size,
            session_uuid=self.session_uuid,
            config=self.config,
        )
        result.shares = shares
        return result

    def __apply_private_op(self, y: "ReplicatedSharedTensor",
                           op_str: str) -> "ReplicatedSharedTensor":
        """Apply an operation on 2 RSTensors (secret shared values).

        Args:
            y (ReplicatedSharedTensor): Tensor to apply the operation
            op_str (str): The operation

        Returns:
            ReplicatedSharedTensor: The operation "op_str" applied on "self" and "y"

        Raises:
            ValueError: If "op_str" not supported.
        """
        y, session = ReplicatedSharedTensor.sanity_checks(self, y)

        op = ReplicatedSharedTensor.get_op(self.ring_size, op_str)

        shares = []
        if op_str in {"add", "sub"}:
            for x_share, y_share in zip(self.shares, y.shares):
                shares.append(op(x_share, y_share))
        else:
            raise ValueError(f"{op_str} not supported")

        result = ReplicatedSharedTensor(
            ring_size=self.ring_size,
            session_uuid=self.session_uuid,
            config=self.config,
        )
        result.shares = shares
        return result

    def __apply_op(
        self,
        y: Union["ReplicatedSharedTensor", torch.Tensor, float, int],
        op_str: str,
    ) -> "ReplicatedSharedTensor":
        """Apply a given operation ".

         This function checks if "y" is private or public value.

        Args:
            y (Union[ReplicatedSharedTensor,torch.Tensor, float, int]): tensor
                to apply the operation.
            op_str (str): the operation.

        Returns:
            ReplicatedSharedTensor: the operation "op_str" applied on "self" and "y"
        """
        is_private = isinstance(y, ReplicatedSharedTensor)

        if is_private:
            result = self.__apply_private_op(y, op_str)
        else:
            result = self.__apply_public_op(y, op_str)

        return result

    def add(
        self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]
    ) -> "ReplicatedSharedTensor":
        """Apply the "add" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self + y

        Returns:
            ReplicatedSharedTensor: Result of the operation.
        """
        return self.__apply_op(y, "add")

    def sub(
        self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]
    ) -> "ReplicatedSharedTensor":
        """Apply the "sub" operation between  "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self - y

        Returns:
            ReplicatedSharedTensor: Result of the operation.
        """
        return self.__apply_op(y, "sub")

    def rsub(
        self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]
    ) -> "ReplicatedSharedTensor":
        """Apply the "sub" operation between "y" and "self".

        Args:
            y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): y - self

        Returns:
            ReplicatedSharedTensor: Result of the operation.
        """
        return self.__apply_op(y, "sub")

    def mul(
        self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]
    ) -> "ReplicatedSharedTensor":
        """Apply the "mul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self*y

        Returns:
            ReplicatedSharedTensor: Result of the operation.

        Raises:
            ValueError: Raised when private mul is performed with parties != 3.

        """
        y_tensor, session = self.sanity_checks(self, y)
        is_private = isinstance(y, ReplicatedSharedTensor)

        op_str = "mul"
        op = ReplicatedSharedTensor.get_op(self.ring_size, op_str)
        if is_private:
            if session.nr_parties == 3:
                from sympc.protocol import Falcon

                result = [
                    Falcon.multiplication_protocol(self, y_tensor, op_str)
                ]
            else:
                raise ValueError(
                    "Private mult between ReplicatedSharedTensors is allowed only for 3 parties"
                )
        else:
            result = [op(share, y_tensor.shares[0]) for share in self.shares]

        tensor = ReplicatedSharedTensor(ring_size=self.ring_size,
                                        session_uuid=self.session_uuid,
                                        config=self.config)
        tensor.shares = result

        return tensor

    def matmul(
        self, y: Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]
    ) -> "ReplicatedSharedTensor":
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ReplicatedSharedTensor"]): self@y

        Returns:
            ReplicatedSharedTensor: Result of the operation.

        Raises:
            ValueError: Raised when private matmul is performed with parties != 3.

        """
        y_tensor, session = self.sanity_checks(self, y)
        is_private = isinstance(y, ReplicatedSharedTensor)

        op_str = "matmul"

        if is_private:
            if session.nr_parties == 3:
                from sympc.protocol import Falcon

                result = [
                    Falcon.multiplication_protocol(self, y_tensor, op_str)
                ]
            else:
                raise ValueError(
                    "Private matmul between ReplicatedSharedTensors is allowed only for 3 parties"
                )
        else:
            result = [
                operator.matmul(share, y_tensor.shares[0])
                for share in self.shares
            ]

        tensor = ReplicatedSharedTensor(ring_size=self.ring_size,
                                        session_uuid=self.session_uuid,
                                        config=self.config)
        tensor.shares = result

        return tensor

    def truediv(self, y: Union[int, torch.Tensor]) -> "ReplicatedSharedTensor":
        """Apply the "div" operation between "self" and "y".

        Args:
            y (Union[int , torch.Tensor]): Denominator.

        Returns:
            ReplicatedSharedTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer or LongTensor.
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError(
                "Div works (for the moment) only with integers and LongTensor!"
            )

        res = ReplicatedSharedTensor(session_uuid=self.session_uuid,
                                     config=self.config,
                                     ring_size=self.ring_size)
        res.shares = [share // y for share in self.shares]
        return res

    def rshift(self, y: int) -> "ReplicatedSharedTensor":
        """Apply the "rshift" operation to "self".

        Args:
            y (int): shift value

        Returns:
            ReplicatedSharedTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer.
            ValueError : If invalid shift value is provided.
        """
        if not isinstance(y, int):
            raise ValueError("Right Shift works only with integers!")

        ring_bits = get_nr_bits(self.ring_size)
        if y < 0 or y > ring_bits - 1:
            raise ValueError(
                f"Invalid value for right shift: {y}, must be in range:[0,{ring_bits-1}]"
            )

        res = ReplicatedSharedTensor(session_uuid=self.session_uuid,
                                     config=self.config,
                                     ring_size=self.ring_size)
        res.shares = [share >> y for share in self.shares]
        return res

    def bit_extraction(self, pos: int = 0) -> "ReplicatedSharedTensor":
        """Extracts the bit at the specified position.

        Args:
            pos (int): position to extract bit.

        Returns:
            ReplicatedSharedTensor : extracted bits at specific position.

        Raises:
            ValueError: If invalid position is provided.
        """
        ring_bits = get_nr_bits(self.ring_size)
        if pos < 0 or pos > ring_bits - 1:
            raise ValueError(
                f"Invalid position for bit_extraction: {pos}, must be in range:[0,{ring_bits-1}]"
            )
        shares = []
        # logical shift
        bit_mask = torch.ones(self.shares[0].shape,
                              dtype=self.shares[0].dtype) << pos
        shares = [share & bit_mask for share in self.shares]
        rst = ReplicatedSharedTensor(
            shares=shares,
            session_uuid=self.session_uuid,
            config=Config(encoder_base=1, encoder_precision=0),
            ring_size=2,
        )
        return rst

    def rmatmul(self, y):
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y: y @ self

        Raises:
            NotImplementedError: Raised when implementation not present
        """
        raise NotImplementedError

    def xor(
        self, y: Union[int, torch.Tensor, "ReplicatedSharedTensor"]
    ) -> "ReplicatedSharedTensor":
        """Apply the "xor" operation between "self" and "y".

        Args:
            y: public bit

        Returns:
            ReplicatedSharedTensor: Result of the operation.

        Raises:
            ValueError : If ring size is invalid.
        """
        if self.ring_size == 2:
            return self + y
        elif self.ring_size in RING_SIZE_TO_TYPE:
            return self + y - (self * y * 2)
        else:
            raise ValueError(
                f"The ring_size {self.ring_size} is not supported.")

    def lt(self, y):
        """Lower than operator.

        Args:
            y: self<y

        Raises:
            NotImplementedError: Raised when implementation not present
        """
        raise NotImplementedError

    def gt(self, y):
        """Greater than operator.

        Args:
            y: self>y

        Raises:
            NotImplementedError: Raised when implementation not present
        """
        raise NotImplementedError

    def eq(self, y: Any) -> bool:
        """Equal operator.

        Check if "self" is equal with another object given a set of attributes to compare.

        Args:
            y (Any): Object to compare

        Returns:
            bool: True if equal, False otherwise.
        """
        if not (torch.cat(self.shares) == torch.cat(y.shares)).all():
            return False

        if self.config != y.config:
            return False

        if self.session_uuid and y.session_uuid and self.session_uuid != y.session_uuid:
            return False

        if self.ring_size != y.ring_size:
            return False

        return True

    def __getitem__(self, key: int) -> torch.Tensor:
        """Allows to subset shares.

        Args:
            key (int): Index of the share to retrieve.

        Returns:
            share (torch.Tensor): Returned share.
        """
        return self.shares[key]

    def __setitem__(self, key: int, newvalue: torch.Tensor) -> None:
        """Allows to set share value to new value.

        Args:
            key (int): The share to be retrieved.
            newvalue (torch.Tensor): New value of share.

        """
        self.shares[key] = newvalue

    def ne(self, y):
        """Not Equal operator.

        Args:
            y: self!=y

        Raises:
            NotImplementedError: Raised when implementation not present
        """
        raise NotImplementedError

    @staticmethod
    def shares_sum(shares: List[torch.Tensor], ring_size: int) -> torch.Tensor:
        """Returns sum of tensors based on ring_size.

        Args:
            shares (List[torch.Tensor]) : List of tensors.
            ring_size (int): Ring size of share associated with the tensors.

        Returns:
            value (torch.Tensor): sum of the tensors.
        """
        if ring_size == 2:
            return reduce(lambda x, y: x ^ y, shares)
        elif ring_size == PRIME_NUMBER:
            return reduce(ReplicatedSharedTensor.addmodprime, shares)
        else:
            return sum(shares)

    @staticmethod
    def _request_and_get(
        share_ptr: "ReplicatedSharedTensor",
    ) -> "ReplicatedSharedTensor":
        """Function used to request and get a share - Duet Setup.

        Args:
            share_ptr (ReplicatedSharedTensor): input ReplicatedSharedTensor

        Returns:
            ReplicatedSharedTensor : The ReplicatedSharedTensor in local.
        """
        if not ispointer(share_ptr):
            return share_ptr
        elif not islocal(share_ptr):
            share_ptr.request(block=True)
        res = share_ptr.get_copy()
        return res

    @staticmethod
    def __reconstruct_semi_honest(
        share_ptrs: List["ReplicatedSharedTensor"],
        get_shares: bool = False,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Reconstruct value from shares.

        Args:
            share_ptrs (List[ReplicatedSharedTensor]): List of RSTensor pointers.
            get_shares (bool): Retrieve only shares.

        Returns:
            reconstructed_value (torch.Tensor): Reconstructed value.
        """
        request = ReplicatedSharedTensor._request_and_get
        request_wrap = parallel_execution(request)
        args = [[share] for share in share_ptrs[:2]]
        local_shares = request_wrap(args)

        shares = [local_shares[0].shares[0]]
        shares.extend(local_shares[1].shares)

        if get_shares:
            return shares

        ring_size = local_shares[0].ring_size

        return ReplicatedSharedTensor.shares_sum(shares, ring_size)

    @staticmethod
    def __reconstruct_malicious(
        share_ptrs: List["ReplicatedSharedTensor"],
        get_shares: bool = False,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Reconstruct value from shares.

        Args:
            share_ptrs (List[ReplicatedSharedTensor]): List of RSTensor pointers.
            get_shares (bool): Retrieve only shares.

        Returns:
            reconstructed_value (torch.Tensor): Reconstructed value.

        Raises:
            ValueError: When parties share values are not equal.
        """
        nparties = len(share_ptrs)

        # Get shares from all parties
        request = ReplicatedSharedTensor._request_and_get
        request_wrap = parallel_execution(request)
        args = [[share] for share in share_ptrs]
        local_shares = request_wrap(args)
        ring_size = local_shares[0].ring_size
        shares_sum = ReplicatedSharedTensor.shares_sum

        all_shares = [rst.shares for rst in local_shares]
        # reconstruct shares from all parties and verify
        value = None
        for party_rank in range(nparties):
            next_rank = (party_rank + 1) % nparties
            tensor = shares_sum(
                [all_shares[party_rank][0]] + all_shares[next_rank],
                ring_size,
            )

            if value is None:
                value = tensor
            elif (tensor != value).any():
                raise ValueError(
                    "Reconstruction values from all parties are not equal.")

        if get_shares:
            return all_shares

        return value

    @staticmethod
    def reconstruct(
        share_ptrs: List["ReplicatedSharedTensor"],
        get_shares: bool = False,
        security_type: str = "semi-honest",
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Reconstruct value from shares.

        Args:
            share_ptrs (List[ReplicatedSharedTensor]): List of RSTensor pointers.
            get_shares (bool): Retrieve only shares.
            security_type (str): Type of security followed by the protocol.

        Returns:
            reconstructed_value (torch.Tensor): Reconstructed value.

        Raises:
            ValueError: If an invalid security type is provided.
            ValueError: If share pointers are not provided.
        """
        if not share_ptrs:
            raise ValueError(
                "Share pointers must be provided for reconstruction.")
        if security_type == "malicious":
            return ReplicatedSharedTensor.__reconstruct_malicious(
                share_ptrs, get_shares)

        elif security_type == "semi-honest":
            return ReplicatedSharedTensor.__reconstruct_semi_honest(
                share_ptrs, get_shares)

        raise ValueError("Invalid security Type")

    @staticmethod
    def distribute_shares_to_party(
        shares: List[Union[ShareTensor, torch.Tensor]],
        party_rank: int,
        session: Session,
        ring_size: int,
        config: Config,
    ) -> "ReplicatedSharedTensor":
        """Distributes shares to party.

        Args:
            shares (List[Union[ShareTensor,torch.Tensor]]): Shares to be distributed.
            party_rank (int): Rank of party.
            session (Session): Current session
            ring_size (int): Ring size of tensor to distribute
            config (Config): The configuration(base,precision) of the tensor.

        Returns:
            tensor (ReplicatedSharedTensor): Tensor with shares

        Raises:
            TypeError: Invalid share class.
        """
        party = session.parties[party_rank]
        nshares = session.nr_parties - 1
        party_shares = []

        for share_index in range(party_rank, party_rank + nshares):
            share = shares[share_index % (nshares + 1)]

            if isinstance(share, torch.Tensor):
                party_shares.append(share)

            elif isinstance(share, ShareTensor):
                party_shares.append(share.tensor)

            else:
                raise TypeError(f"{type(share)} is an invalid share class")

        tensor = ReplicatedSharedTensor(
            session_uuid=session.rank_to_uuid[party_rank],
            config=config,
            ring_size=ring_size,
        )
        tensor.shares = party_shares
        return tensor.send(party)

    @staticmethod
    def distribute_shares(
        shares: List[Union[ShareTensor, torch.Tensor]],
        session: Session,
        ring_size: Optional[int] = None,
        config: Optional[Config] = None,
    ) -> List["ReplicatedSharedTensor"]:
        """Distribute a list of shares.

        Args:
            shares (List[Union[ShareTensor, torch.Tensor]]): List of shares to distribute.
            session (Session): Session.
            ring_size (int): ring_size the shares belong to.
            config (Config): The configuration(base,precision) of the tensor.

        Returns:
            List of ReplicatedSharedTensors.

        Raises:
            TypeError: When the datatype of shares is invalid.
            ValueError: When the number of shares differs from the number of parties.

        """
        if not isinstance(shares, (list, tuple)):
            raise TypeError(
                "Shares to be distributed should be a list of shares")

        if len(shares) != session.nr_parties:
            raise ValueError(
                "Number of shares to be distributed should be same as number of parties"
            )

        if ring_size is None:
            ring_size = session.ring_size
        if config is None:
            config = session.config

        args = [[shares, party_rank, session, ring_size, config]
                for party_rank in range(session.nr_parties)]

        return [
            ReplicatedSharedTensor.distribute_shares_to_party(*arg)
            for arg in args
        ]

    @staticmethod
    def hook_property(property_name: str) -> Any:
        """Hook a framework property (only getter).

        Ex:
         * if we call "shape" we want to call it on the underlying tensor
        and return the result
         * if we call "T" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            property_name (str): property to hook

        Returns:
            A hooked property
        """
        def property_new_rs_tensor_getter(
                _self: "ReplicatedSharedTensor") -> Any:
            shares = []

            for share in _self.shares:
                tensor = getattr(share, property_name)
                shares.append(tensor)

            res = ReplicatedSharedTensor(
                session_uuid=_self.session_uuid,
                config=_self.config,
                ring_size=_self.ring_size,
            )
            res.shares = shares
            return res

        def property_getter(_self: "ReplicatedSharedTensor") -> Any:
            prop = getattr(_self.shares[0], property_name)
            return prop

        if property_name in PROPERTIES_NEW_RS_TENSOR:
            res = property(property_new_rs_tensor_getter, None)
        else:
            res = property(property_getter, None)

        return res

    @staticmethod
    def hook_method(method_name: str) -> Callable[..., Any]:
        """Hook a framework method such that we know how to treat it given that we call it.

        Ex:
         * if we call "numel" we want to call it on the underlying tensor
        and return the result
         * if we call "unsqueeze" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            method_name (str): method to hook

        Returns:
            A hooked method

        """
        def method_new_rs_tensor(_self: "ReplicatedSharedTensor",
                                 *args: List[Any], **kwargs: Dict[Any,
                                                                  Any]) -> Any:
            shares = []
            for share in _self.shares:
                tensor = getattr(share, method_name)(*args, **kwargs)
                shares.append(tensor)

            res = ReplicatedSharedTensor(
                session_uuid=_self.session_uuid,
                config=_self.config,
                ring_size=_self.ring_size,
            )
            res.shares = shares
            return res

        def method(_self: "ReplicatedSharedTensor", *args: List[Any],
                   **kwargs: Dict[Any, Any]) -> Any:
            method = getattr(_self.shares[0], method_name)
            res = method(*args, **kwargs)
            return res

        if method_name in METHODS_NEW_RS_TENSOR:
            res = method_new_rs_tensor
        else:
            res = method

        return res

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = truediv
    __floordiv__ = truediv
    __xor__ = xor
    __eq__ = eq
    __rshift__ = rshift
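
A minimal local sketch of the replicated layout and the semi-honest reconstruction used above, with plain torch tensors instead of parties and sessions; the concrete share values and the 3-party layout are illustrative assumptions, not taken from the snippet:

import torch

# A secret split into three additive shares.
secret = torch.tensor([10, 20], dtype=torch.int64)
share_0 = torch.tensor([3, 5], dtype=torch.int64)
share_1 = torch.tensor([4, 7], dtype=torch.int64)
share_2 = secret - share_0 - share_1  # shares sum back to the secret

# Replicated layout: party r holds every share except share (r - 1) mod 3,
# mirroring distribute_shares_to_party above.
party_shares = {
    0: [share_0, share_1],
    1: [share_1, share_2],
    2: [share_2, share_0],
}

# Semi-honest reconstruction: share 0 from party 0 plus both shares of
# party 1 covers each share exactly once, as in __reconstruct_semi_honest.
reconstructed = party_shares[0][0] + sum(party_shares[1])
assert torch.equal(reconstructed, secret)
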
Example No. 11
class ShareTensor:
    """This class represents 1 share that a party holds when doing secret
    sharing.

    Arguments:
        data (Optional[Any]): the share a party holds
        session (Optional[Any]): the session to which this share belongs
        encoder_base (int): the base for the encoder
        encoder_precision (int): the precision for the encoder
        ring_size (int): field used for the operations applied on the shares

    Attributes:
        Syft Serializable Attributes

        id (UID): the id to store the session
        tags (Optional[List[str]]): an optional list of strings that are tags used for search
        description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session (Session): keeps track of the session to which this share belongs
        fp_encoder (FixedPointEncoder): the encoder used to convert a share from/to fixed point
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session",
        "fp_encoder",
    }

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        session: Optional[Session] = None,
        encoder_base: int = 2,
        encoder_precision: int = 16,
        ring_size: int = 2**64,
    ) -> None:
        """Initializer for the ShareTensor."""

        if session is None:
            self.session = Session(ring_size=ring_size)
            self.session.config.encoder_precision = encoder_precision
            self.session.config.encoder_base = encoder_base

        else:
            self.session = session
            encoder_precision = self.session.config.encoder_precision
            encoder_base = self.session.config.encoder_base

        # TODO: It looks like the same logic as above
        self.fp_encoder = FixedPointEncoder(base=encoder_base,
                                            precision=encoder_precision)

        self.tensor: Optional[torch.Tensor] = None

        if data is not None:
            tensor_type = self.session.tensor_type
            self.tensor = self._encode(data).type(tensor_type)

    def _encode(self, data):
        return self.fp_encoder.encode(data)

    def decode(self):
        return self._decode()

    def _decode(self):
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(x: "ShareTensor", y: Union[int, float, torch.Tensor,
                                                 "ShareTensor"],
                      op_str: str) -> "ShareTensor":
        """Check the type of "y" and convert it to a share if necessary.

        :return: the y value
        :rtype: ShareTensor, int or Integer type Tensor
        """
        if not isinstance(y, ShareTensor):
            y = ShareTensor(data=y, session=x.session)

        return y

    def apply_function(self, y: Union["ShareTensor", torch.Tensor, int, float],
                       op_str: str) -> "ShareTensor":
        """Apply a given operation.

        :return: the result of applying "op_str" on "self" and y
        :rtype: ShareTensor
        """
        op = getattr(operator, op_str)

        if isinstance(y, ShareTensor):
            value = op(self.tensor, y.tensor)
        else:
            value = op(self.tensor, y)

        res = ShareTensor(session=self.session)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        :return: self + y
        :rtype: ShareTensor
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        res = self.apply_function(y_share, "add")
        return res

    def sub(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        :return: self - y
        :rtype: ShareTensor
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = self.apply_function(y_share, "sub")
        return res

    def rsub(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        :return: y - self
        :rtype: ShareTensor
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = y_share.apply_function(self, "sub")
        return res

    def mul(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        :return: self * y
        :rtype: ShareTensor
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def matmul(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        :return: self @ y
        :rtype: ShareTensor
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the reversed "matmul" operation between "self" and "y".

        :return: y @ self
        :rtype: ShareTensor
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "div" operation between "self" and "y". Currently,
        NotImplemented.

        :return: self / y
        :rtype: ShareTensor
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        res = ShareTensor(session=self.session)
        res.tensor = self.tensor // y

        return res

    def __getattr__(self, attr_name: str) -> Any:
        """Get the attribute from the ShareTensor. If the attribute is not
        found at the ShareTensor level, the it would look for in the the
        "tensor".

        :return: the attribute value
        :rtype: Anything
        """
        # Default to some tensor specific attributes like
        # size, shape, etc.
        tensor = self.tensor
        return getattr(tensor, attr_name)

    def __gt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> bool:
        """Check if "self" is bigger than "y".

        :return: self > y
        :rtype: bool
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        res = self.tensor > y_share.tensor
        return res

    def __lt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> bool:
        """Check if "self" is less than "y".

        :return: self < y
        :rtype: bool
        """

        y_share = ShareTensor.sanity_checks(self, y, "lt")
        res = self.tensor < y_share.tensor
        return res

    def __str__(self) -> str:
        """Return the string representation of ShareTensor."""
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Check if "self" is equal with another object given a set of
        attributes to compare.

        :return: if self and other are equal
        :rtype: bool
        """

        if not (self.tensor == other.tensor).all():
            return False

        if not (self.session == other.session):
            return False

        return True

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
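
A short usage sketch for the standalone (no MPC) case handled above, assuming the ShareTensor class from this snippet is in scope and that a Session created with no parties reports nr_parties == 0, so mul rescales the fixed-point encoding locally:

import torch

x = ShareTensor(data=torch.tensor([1.5, 2.5]))
y = ShareTensor(data=torch.tensor([2.0, 4.0]))

# add/sub operate directly on the fixed-point encodings.
print((x + y).decode())  # expected: tensor([3.5000, 6.5000])

# mul multiplies the encodings (scale**2); with nr_parties == 0 the result
# is divided by the scale once, returning to a single-scale encoding.
print((x * y).decode())  # expected: tensor([ 3., 10.])
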
Example No. 12
class ReplicatedSharedTensor(metaclass=SyMPCTensor):
    """RSTensor is used when a party holds more than a single share,required by various protocols.

    Arguments:
       shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list
           from which RSTensor is created.


    Attributes:
       shares: The shares held by the party
    """

    AUTOGRAD_IS_ON: bool = True

    # Used by the SyMPCTensor metaclass
    METHODS_FORWARD = {"numel", "t", "unsqueeze", "view", "sum", "clone"}
    PROPERTIES_FORWARD = {"T"}

    def __init__(
        self,
        shares: Optional[List[Union[float, int, torch.Tensor]]] = None,
        config: Config = Config(encoder_base=2, encoder_precision=16),
        session_uuid: Optional[UUID] = None,
        ring_size: int = 2**64,
    ):
        """Initialize ReplicatedSharedTensor.

        Args:
            shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list
                from which RSTensor is created.
            config (Config): The configuration where we keep the encoder precision and base.
            session_uuid (Optional[UUID]): Used to keep track of a share that is associated with a
                remote session
            ring_size (int): field used for the operations applied on the shares.
                Defaults to 2**64.

        """
        self.session_uuid = session_uuid
        self.ring_size = ring_size

        self.config = config
        self.fp_encoder = FixedPointEncoder(base=config.encoder_base,
                                            precision=config.encoder_precision)

        tensor_type = get_type_from_ring(ring_size)
        self.shares = []
        if shares is not None:
            for share in shares:
                self.shares.append(self._encode(share).to(tensor_type))

    def _encode(self, data):
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPointEncoder.

        Returns:
            List[torch.Tensor]: Decoded values
        """
        return self._decode()

    def _decode(self):
        """Decodes shares list of RSTensor via FixedPointEncoder.

        Returns:
            List[torch.Tensor]: Decoded values
        """
        shares = []
        for share in self.shares:
            tensor = self.fp_encoder.decode(share.type(torch.LongTensor))
            shares.append(tensor)

        return shares

    def add(self, y):
        """Apply the "add" operation between "self" and "y".

        Args:
            y: self+y

        """

    def sub(self, y):
        """Apply the "sub" operation between "self" and "y".

        Args:
            y: self-y

        """

    def rsub(self, y):
        """Apply the "sub" operation between "y" and "self".

        Args:
            y: y-self

        """

    def mul(self, y):
        """Apply the "mul" operation between "self" and "y".

        Args:
            y: self*y

        """

    def truediv(self, y):
        """Apply the "div" operation between "self" and "y".

        Args:
            y: self/y

        """

    def matmul(self, y):
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y: self@y

        """

    def rmatmul(self, y):
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y: y@self

        """

    def xor(self, y):
        """Apply the "xor" operation between "self" and "y".

        Args:
            y: self^y

        """

    def lt(self, y):
        """Lower than operator.

        Args:
            y: self<y

        """

    def gt(self, y):
        """Greater than operator.

        Args:
            y: self>y

        """

    def eq(self, y):
        """Equal operator.

        Args:
            y: self==y

        """

    def ne(self, y):
        """Not Equal operator.

        Args:
            y: self!=y

        """

    @staticmethod
    def hook_property(property_name: str) -> Any:
        """Hook a framework property (only getter).

        Ex:
         * if we call "shape" we want to call it on the underlying tensor
        and return the result
         * if we call "T" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            property_name (str): property to hook

        Returns:
            A hooked property
        """
        def property_new_rs_tensor_getter(
                _self: "ReplicatedSharedTensor") -> Any:
            shares = []

            for share in _self.shares:
                tensor = getattr(share, property_name)
                shares.append(tensor)

            res = ReplicatedSharedTensor(session_uuid=_self.session_uuid,
                                         config=_self.config)
            res.shares = shares
            return res

        def property_getter(_self: "ReplicatedSharedTensor") -> Any:
            prop = getattr(_self.shares[0], property_name)
            return prop

        if property_name in PROPERTIES_NEW_RS_TENSOR:
            res = property(property_new_rs_tensor_getter, None)
        else:
            res = property(property_getter, None)

        return res

    @staticmethod
    def hook_method(method_name: str) -> Callable[..., Any]:
        """Hook a framework method such that we know how to treat it given that we call it.

        Ex:
         * if we call "numel" we want to call it on the underlying tensor
        and return the result
         * if we call "unsqueeze" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            method_name (str): method to hook

        Returns:
            A hooked method

        """
        def method_new_rs_tensor(_self: "ReplicatedSharedTensor",
                                 *args: List[Any], **kwargs: Dict[Any,
                                                                  Any]) -> Any:
            shares = []
            for share in _self.shares:
                tensor = getattr(share, method_name)(*args, **kwargs)
                shares.append(tensor)

            res = ReplicatedSharedTensor(session_uuid=_self.session_uuid,
                                         config=_self.config)
            res.shares = shares
            return res

        def method(_self: "ReplicatedSharedTensor", *args: List[Any],
                   **kwargs: Dict[Any, Any]) -> Any:
            method = getattr(_self.shares[0], method_name)
            res = method(*args, **kwargs)
            return res

        if method_name in METHODS_NEW_RS_TENSOR:
            res = method_new_rs_tensor
        else:
            res = method

        return res

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = truediv
    __xor__ = xor
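
The SyMPCTensor metaclass itself is not shown in this snippet; below is a rough, hypothetical sketch of the wiring it is assumed to perform with METHODS_FORWARD and PROPERTIES_FORWARD via the hook_method/hook_property helpers above:

# Hypothetical sketch -- the name SyMPCTensorSketch and the exact behavior
# of the real SyMPCTensor metaclass are assumptions, not taken from this
# snippet.
class SyMPCTensorSketch(type):
    def __new__(mcs, name, bases, namespace):
        cls = super().__new__(mcs, name, bases, namespace)
        # Replace each forwarded method/property with its hooked version.
        for method_name in namespace.get("METHODS_FORWARD", set()):
            setattr(cls, method_name, cls.hook_method(method_name))
        for property_name in namespace.get("PROPERTIES_FORWARD", set()):
            setattr(cls, property_name, cls.hook_property(property_name))
        return cls
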
Example No. 13
class ShareTensor(metaclass=SyMPCTensor):
    """Single Share representation.

    Arguments:
        data (Optional[Any]): the share a party holds
        session (Optional[Any]): the session to which this share belongs
        encoder_base (int): the base for the encoder
        encoder_precision (int): the precision for the encoder
        ring_size (int): field used for the operations applied on the shares

    Attributes:
        Syft Serializable Attributes

        id (UID): the id to store the session
        tags (Optional[List[str]]): an optional list of strings that are tags used for search
        description (Optional[str]): an optional string used to describe the session

        tensor (Any): the value of the share
        session (Session): keeps track of the session to which this share belongs
        fp_encoder (FixedPointEncoder): the encoder used to convert a share from/to fixed point
    """

    __slots__ = {
        # Populated in Syft
        "id",
        "tags",
        "description",
        "tensor",
        "session",
        "fp_encoder",
    }

    # Used by the SyMPCTensor metaclass
    METHODS_FORWARD: Set[str] = {
        "numel",
        "unsqueeze",
        "t",
        "view",
        "sum",
        "clone",
        "flatten",
        "reshape",
    }
    PROPERTIES_FORWARD: Set[str] = {"T", "shape"}

    def __init__(
        self,
        data: Optional[Union[float, int, torch.Tensor]] = None,
        session: Optional[Session] = None,
        encoder_base: int = 2,
        encoder_precision: int = 16,
        ring_size: int = 2**64,
    ) -> None:
        """Initialize ShareTensor.

        Args:
            data (Optional[Any]): The share a party holds. Defaults to None
            session (Optional[Any]): The session to which this share belongs.
                Defaults to None.
            encoder_base (int): The base for the encoder. Defaults to 2.
            encoder_precision (int): the precision for the encoder. Defaults to 16.
            ring_size (int): field used for the operations applied on the shares.
                Defaults to 2**64.
        """
        if session is None:
            self.session = Session(ring_size=ring_size)
            self.session.config.encoder_precision = encoder_precision
            self.session.config.encoder_base = encoder_base

        else:
            self.session = session
            encoder_precision = self.session.config.encoder_precision
            encoder_base = self.session.config.encoder_base

        # TODO: It looks like the same logic as above
        self.fp_encoder = FixedPointEncoder(base=encoder_base,
                                            precision=encoder_precision)

        self.tensor: Optional[torch.Tensor] = None
        if data is not None:
            tensor_type = self.session.tensor_type
            self.tensor = self._encode(data).type(tensor_type)

    def _encode(self, data):
        return self.fp_encoder.encode(data)

    def decode(self):
        """Decode via FixedPrecisionEncoder.

        Returns:
            torch.Tensor: Decoded value
        """
        return self._decode()

    def _decode(self):
        return self.fp_encoder.decode(self.tensor.type(torch.LongTensor))

    @staticmethod
    def sanity_checks(x: "ShareTensor", y: Union[int, float, torch.Tensor,
                                                 "ShareTensor"],
                      op_str: str) -> "ShareTensor":
        """Check the type of "y" and covert it to share if necessary.

        Args:
            x (ShareTensor): Typically "self".
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Tensor to check.
            op_str (str): String operator.

        Returns:
            ShareTensor: the converted y value.

        """
        if not isinstance(y, ShareTensor):
            y = ShareTensor(data=y, session=x.session)

        return y

    def apply_function(self, y: Union["ShareTensor", torch.Tensor, int, float],
                       op_str: str) -> "ShareTensor":
        """Apply a given operation.

        Args:
            y (Union["ShareTensor", torch.Tensor, int, float]): tensor to apply the operator.
            op_str (str): Operator.

        Returns:
            ShareTensor: Result of the operation.
        """
        op = getattr(operator, op_str)

        if isinstance(y, ShareTensor):
            value = op(self.tensor, y.tensor)
        else:
            value = op(self.tensor, y)

        res = ShareTensor(session=self.session)
        res.tensor = value
        return res

    def add(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "add" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self + y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "add")
        res = self.apply_function(y_share, "add")
        return res

    def sub(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self - y

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = self.apply_function(y_share, "sub")
        return res

    def rsub(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "sub" operation between "y" and "self".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): y - self

        Returns:
            ShareTensor. Result of the operation.
        """
        y_share = ShareTensor.sanity_checks(self, y, "sub")
        res = y_share.apply_function(self, "sub")
        return res

    def mul(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "mul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self * y

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "mul")
        res = self.apply_function(y, "mul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor //= self.fp_encoder.scale

        return res

    def xor(self, y: Union[int, torch.Tensor, "ShareTensor"]) -> "ShareTensor":
        """Apply the "xor" operation between "self" and "y".

        Args:
            y (Union[int, torch.Tensor, "ShareTensor"]): self xor y

        Returns:
            ShareTensor: Result of the operation.
        """
        res = ShareTensor(session=self.session)

        if isinstance(y, ShareTensor):
            res.tensor = self.tensor ^ y.tensor
        else:
            res.tensor = self.tensor ^ y

        return res

    def matmul(
            self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "matmul" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): self @ y.

        Returns:
            ShareTensor: Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        res = self.apply_function(y, "matmul")

        if self.session.nr_parties == 0:
            # We are using a simple share without using the MPCTensor
            # In case we used the MPCTensor - the division would have
            # been done in the protocol
            res.tensor = res.tensor // self.fp_encoder.scale

        return res

    def rmatmul(self, y: torch.Tensor) -> "ShareTensor":
        """Apply the "rmatmul" operation between "y" and "self".

        Args:
            y (torch.Tensor): y @ self

        Returns:
            ShareTensor. Result of the operation.
        """
        y = ShareTensor.sanity_checks(self, y, "matmul")
        return y.matmul(self)

    def div(self, y: Union[int, float, torch.Tensor,
                           "ShareTensor"]) -> "ShareTensor":
        """Apply the "div" operation between "self" and "y".

        Args:
            y (Union[int, float, torch.Tensor, "ShareTensor"]): Denominator.

        Returns:
            ShareTensor: Result of the operation.

        Raises:
            ValueError: If y is not an integer or LongTensor.
        """
        if not isinstance(y, (int, torch.LongTensor)):
            raise ValueError("Div works (for the moment) only with integers!")

        res = ShareTensor(session=self.session)
        res.tensor = self.tensor // y

        return res

    def __gt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> bool:
        """Greater than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            bool: Result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "gt")
        res = self.tensor > y_share.tensor
        return res

    def __lt__(self, y: Union["ShareTensor", torch.Tensor, int]) -> bool:
        """Lower than operator.

        Args:
            y (Union["ShareTensor", torch.Tensor, int]): Tensor to compare.

        Returns:
            bool: Result of the comparison.
        """
        y_share = ShareTensor.sanity_checks(self, y, "lt")
        res = self.tensor < y_share.tensor
        return res

    def __str__(self) -> str:
        """Representation.

        Returns:
            str: Return the string representation of ShareTensor.
        """
        type_name = type(self).__name__
        out = f"[{type_name}]"
        out = f"{out}\n\t| {self.fp_encoder}"
        out = f"{out}\n\t| Data: {self.tensor}"

        return out

    def __repr__(self) -> str:
        """Representation.

        Returns:
            String representation.
        """
        return self.__str__()

    def __eq__(self, other: Any) -> bool:
        """Equal operator.

        Check if "self" is equal with another object given a set of
            attributes to compare.

        Args:
            other (Any): Tensor to compare.

        Returns:
            bool: True if equal False if not.

        """
        if not (self.tensor == other.tensor).all():
            return False

        if not (self.session == other.session):
            return False

        return True

    @staticmethod
    def hook_property(property_name: str) -> Any:
        """Hook a framework property (only getter).

        Ex:
         * if we call "shape" we want to call it on the underlying tensor
        and return the result
         * if we call "T" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            property_name (str): property to hook

        Returns:
            A hooked property
        """
        def property_new_share_tensor_getter(_self: "ShareTensor") -> Any:
            tensor = getattr(_self.tensor, property_name)
            res = ShareTensor(session=_self.session)
            res.tensor = tensor
            return res

        def property_getter(_self: "ShareTensor") -> Any:
            prop = getattr(_self.tensor, property_name)
            return prop

        if property_name in PROPERTIES_NEW_SHARE_TENSOR:
            res = property(property_new_share_tensor_getter, None)
        else:
            res = property(property_getter, None)

        return res

    @staticmethod
    def hook_method(method_name: str) -> Callable[..., Any]:
        """Hook a framework method such that we know how to treat it given that we call it.

        Ex:
         * if we call "numel" we want to call it on the underlying tensor
        and return the result
         * if we call "unsqueeze" we want to call it on the underlying tensor
        but we want to wrap it in the same tensor type

        Args:
            method_name (str): method to hook

        Returns:
            A hooked method
        """
        def method_new_share_tensor(_self: "ShareTensor", *args: List[Any],
                                    **kwargs: Dict[Any, Any]) -> Any:
            method = getattr(_self.tensor, method_name)
            tensor = method(*args, **kwargs)
            res = ShareTensor(session=_self.session)
            res.tensor = tensor
            return res

        def method(_self: "ShareTensor", *args: List[Any],
                   **kwargs: Dict[Any, Any]) -> Any:
            method = getattr(_self.tensor, method_name)
            res = method(*args, **kwargs)
            return res

        if method_name in METHODS_NEW_SHARE_TENSOR:
            res = method_new_share_tensor
        else:
            res = method

        return res

    __add__ = add
    __radd__ = add
    __sub__ = sub
    __rsub__ = rsub
    __mul__ = mul
    __rmul__ = mul
    __matmul__ = matmul
    __rmatmul__ = rmatmul
    __truediv__ = div
    __xor__ = xor
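
Finally, a brief sketch of how the hooked methods and properties above would behave, assuming the ShareTensor class from this snippet (and its SyMPCTensor metaclass) is in scope, and that "unsqueeze" is registered to re-wrap its result -- an assumption about METHODS_NEW_SHARE_TENSOR, which is not shown here:

import torch

x = ShareTensor(data=torch.tensor([1.0, 2.0, 3.0]))
print(x.shape)      # torch.Size([3]) -- raw property, forwarded as-is
y = x.unsqueeze(0)  # presumably re-wrapped into a new ShareTensor
print(y.shape)      # torch.Size([1, 3])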