Code Example #1
    def share(
        self,
        *owners: List[BaseWorker],
        protocol: str = "snn",
        field: Union[int, None] = None,
        dtype: Union[str, None] = None,
        crypto_provider: Union[BaseWorker, None] = None,
        requires_grad: bool = False,
        no_wrap: bool = False,
    ):
        """This is a pass through method which calls .share on the child.

        Args:
            owners (list): A list of BaseWorker objects determining who to send shares to.
            protocol (str): The crypto protocol used to perform the computations
                ('snn', 'fss' or 'falcon').
            field (int or None): The arithmetic field where the shares live.
            dtype (str or None): The dtype of the shares.
            crypto_provider (BaseWorker or None): The worker providing the crypto primitives.
            requires_grad (bool): Should we add AutogradTensor to allow gradient computation,
                default is False.
            no_wrap (bool): If True, the result is not wrapped in a torch wrapper,
                default is False.
        """
        if protocol == "falcon":
            shared_tensor = syft.ReplicatedSharingTensor(
                owner=self.owner).share_secret(self, owners)
            return shared_tensor
        if self.has_child():
            chain = self.child

            kwargs_ = (
                {"requires_grad": requires_grad} if isinstance(chain, syft.PointerTensor) else {}
            )
            shared_tensor = chain.share(
                *owners,
                protocol=protocol,
                field=field,
                dtype=dtype,
                crypto_provider=crypto_provider,
                **kwargs_,
            )
        else:
            if self.type() == "torch.FloatTensor":
                raise TypeError(
                    "FloatTensor cannot be additively shared, Use fix_precision."
                )

            shared_tensor = (
                syft.AdditiveSharingTensor(
                    protocol=protocol,
                    field=field,
                    dtype=dtype,
                    crypto_provider=crypto_provider,
                    owner=self.owner,
                )
                .on(self.copy(), wrap=False)
                .share_secret(*owners)
            )

        if requires_grad and not isinstance(shared_tensor, syft.PointerTensor):
            shared_tensor = syft.AutogradTensor().on(shared_tensor, wrap=False)

        if not no_wrap:
            shared_tensor = shared_tensor.wrap(type=self.dtype)

        return shared_tensor
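
A minimal usage sketch of the method above, assuming a PySyft 0.2-style setup with VirtualWorker; the worker names (alice, bob, crypto_provider) are placeholders and not part of the original snippet.

import torch
import syft

hook = syft.TorchHook(torch)

# Hypothetical workers; any BaseWorker instances would do.
alice = syft.VirtualWorker(hook, id="alice")
bob = syft.VirtualWorker(hook, id="bob")
crypto = syft.VirtualWorker(hook, id="crypto_provider")

# Float tensors must be converted to fixed precision first,
# otherwise the TypeError branch above is raised.
x = torch.tensor([1.0, 2.0, 3.0]).fix_precision()
x_shared = x.share(alice, bob, crypto_provider=crypto, requires_grad=False)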
Code Example #2
    def public_linear_operation(self, plain_text, operator):
        players = self.get_players()
        shares_map = self.get_shares_map()
        # Send the public value to the first player and fold it into that
        # player's first share; the second share is left untouched.
        plain_text = torch.tensor(plain_text).send(players[0])
        shares_map[players[0]] = (
            operator(shares_map[players[0]][0], plain_text),
            shares_map[players[0]][1],
        )
        return syft.ReplicatedSharingTensor(shares_map)
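
A hedged sketch of exercising this helper directly, reusing the 'falcon' branch of Code Example #1; the worker names are placeholders, and calling public_linear_operation by hand (rather than through the tensor's arithmetic operators) is purely for illustration.

import operator
import torch
import syft

hook = syft.TorchHook(torch)
alice = syft.VirtualWorker(hook, id="alice")
bob = syft.VirtualWorker(hook, id="bob")
james = syft.VirtualWorker(hook, id="james")

# Share a secret under the replicated-sharing ('falcon') protocol, then
# combine it with a public constant via the helper above.
x = torch.tensor([5]).share(alice, bob, james, protocol="falcon")
y = x.public_linear_operation(3, operator.add)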
Code Example #3
import torch
import syft


def test_shares_number():
    tensor = syft.ReplicatedSharingTensor()
    secret = torch.tensor(7)
    number_of_shares = 4
    shares = tensor.generate_shares(secret, number_of_shares)
    assert len(shares) == number_of_shares
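
A hedged follow-up to the test above: assuming the generated shares are additive modulo the tensor's ring size (with 2 ** 32 used only as a fallback guess), summing them should recover the secret.

import torch
import syft

tensor = syft.ReplicatedSharingTensor()
secret = torch.tensor(7)
shares = tensor.generate_shares(secret, 4)

# Assumed reconstruction rule: additive shares reduce to the secret mod ring_size.
ring_size = getattr(tensor, "ring_size", 2 ** 32)
assert sum(shares) % ring_size == secret % ring_size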