def share(
    self,
    *owners: BaseWorker,
    protocol: str = "snn",
    field: Union[int, None] = None,
    dtype: Union[str, None] = None,
    crypto_provider: Union[BaseWorker, None] = None,
    requires_grad: bool = False,
    no_wrap: bool = False,
):
    """This is a pass through method which calls .share on the child.

    If the tensor has a child, ``.share`` is forwarded to it. Otherwise the
    tensor itself is secret-shared through an ``AdditiveSharingTensor``.

    Args:
        *owners: The BaseWorker objects determining who to send shares to,
            each passed as a separate positional argument.
        protocol (str): the crypto protocol used to perform the computations
            ('snn' or 'fss'). The value 'falcon' is handled separately: it
            returns the result of ``ReplicatedSharingTensor.share_secret``
            directly, and field, dtype, crypto_provider, requires_grad and
            no_wrap are all ignored in that case.
        field (int or None): The arithmetic field in which the shares live.
        dtype (str or None): The dtype of shares.
        crypto_provider (BaseWorker or None): The worker providing the crypto
            primitives.
        requires_grad (bool): Should we add AutogradTensor to allow gradient
            computation, default is False.
        no_wrap (bool): If True, return the sharing chain as is instead of
            calling ``.wrap(type=self.dtype)`` on it. Default is False.

    Returns:
        The shared tensor, wrapped with ``.wrap(type=self.dtype)`` unless
        ``no_wrap`` is True.

    Raises:
        TypeError: If the tensor has no child and its type is
            ``torch.FloatTensor``; such tensors must go through
            fix_precision first.
    """
    if protocol == "falcon":
        # NOTE(review): this branch returns early, so requires_grad and
        # no_wrap are not applied for replicated sharing — confirm intended.
        shared_tensor = syft.ReplicatedSharingTensor(owner=self.owner).share_secret(
            self, owners
        )
        return shared_tensor

    if self.has_child():
        chain = self.child
        # requires_grad is forwarded only when the child is a PointerTensor;
        # otherwise an AutogradTensor is attached locally further below
        # (that local step is skipped when the result is a PointerTensor).
        kwargs_ = (
            {"requires_grad": requires_grad}
            if isinstance(chain, syft.PointerTensor)
            else {}
        )
        shared_tensor = chain.share(
            *owners,
            protocol=protocol,
            field=field,
            dtype=dtype,
            crypto_provider=crypto_provider,
            **kwargs_,
        )
    else:
        if self.type() == "torch.FloatTensor":
            raise TypeError(
                "FloatTensor cannot be additively shared, Use fix_precision."
            )

        # A copy of self (not self) is placed under the sharing tensor.
        shared_tensor = (
            syft.AdditiveSharingTensor(
                protocol=protocol,
                field=field,
                dtype=dtype,
                crypto_provider=crypto_provider,
                owner=self.owner,
            )
            .on(self.copy(), wrap=False)
            .share_secret(*owners)
        )

    if requires_grad and not isinstance(shared_tensor, syft.PointerTensor):
        shared_tensor = syft.AutogradTensor().on(shared_tensor, wrap=False)

    if not no_wrap:
        shared_tensor = shared_tensor.wrap(type=self.dtype)

    return shared_tensor
def public_linear_operation(self, plain_text, operator):
    """Combine this shared tensor with a public value and return a new tensor.

    The public value is converted with ``torch.tensor`` and sent to the first
    player. ``operator`` is then applied to the first element of that
    player's share pair. All other share entries are carried over unchanged.

    Args:
        plain_text: A public (non-secret) value accepted by ``torch.tensor``.
        operator: A binary callable invoked as ``operator(share, plain_text)``.
            Modifying a single share is only correct for addition or
            subtraction of a public value; presumably callers pass add/sub
            (TODO confirm). Multiplication by a public scalar would need
            every share scaled.

    Returns:
        A new ``syft.ReplicatedSharingTensor`` built from the updated shares
        map. ``self`` is left unmodified.
    """
    players = self.get_players()
    # Work on a shallow copy: writing the updated entry back into the map
    # returned by get_shares_map() would otherwise mutate this tensor's own
    # shares if that map is not already a copy.
    shares_map = self.get_shares_map().copy()
    first_player = players[0]
    plain_text = torch.tensor(plain_text).send(first_player)
    first_pair = shares_map[first_player]
    # NOTE(review): only the first element held by players[0] is updated. If
    # the replicated scheme stores another copy of that share with a
    # different player, that copy is not updated here — verify against the
    # reconstruction/multiplication code.
    shares_map[first_player] = (
        operator(first_pair[0], plain_text),
        first_pair[1],
    )
    return syft.ReplicatedSharingTensor(shares_map)
def test_shares_number():
    """generate_shares must return exactly as many shares as requested."""
    requested_count = 4
    rst = syft.ReplicatedSharingTensor()
    generated = rst.generate_shares(torch.tensor(7), requested_count)
    assert len(generated) == requested_count