def _abstract_mul(self, equation: str, shares: dict, other_shares, **kwargs): """Abstractly Multiplies two tensors Args: equation: a string reprsentation of the equation to be computed in einstein summation form shares: a dictionary <location_id -> PointerTensor) of shares corresponding to self. Equivalent to calling self.child. other_shares: a dictionary <location_id -> PointerTensor) of shares corresponding to the tensor being multiplied by self. """ # check to see that operation is either mul or matmul assert equation == "mul" or equation == "matmul" cmd = getattr(torch, equation) # if someone passes in a constant... (i.e., x + 3) # TODO: Handle public mul more efficiently if not isinstance(other_shares, dict): other_shares = torch.Tensor([other_shares]).share(*self.child.keys()).child assert len(shares) == len(other_shares) if self.crypto_provider is None: raise AttributeError("For multiplication a crytoprovider must be passed.") shares = spdz_mul(cmd, shares, other_shares, self.crypto_provider, self.field) return AdditiveSharingTensor( field=self.field, crypto_provider=self.crypto_provider ).set_shares(shares)
def _private_mul(self, other, equation: str):
    """Multiply this tensor by another AdditiveSharingTensor.

    Args:
        self: an AdditiveSharingTensor
        other: another AdditiveSharingTensor; its ``child`` must contain as
            many entries as ``self.child``
        equation: name of the torch operation to apply, either "mul" or
            "matmul"

    Returns:
        Whatever ``spdz.spdz_mul`` returns for the two operands.

    Raises:
        AssertionError: if ``equation`` is unsupported, ``other`` is not an
            AdditiveSharingTensor, or the two operands differ in the number
            of entries in ``child``.
        AttributeError: if ``self.crypto_provider`` is None.
    """
    # Resolve the torch operation; only mul and matmul are allowed.
    assert equation in ("mul", "matmul")
    torch_op = getattr(torch, equation)

    # The operand must be a shared tensor with a matching number of shares.
    assert isinstance(other, AdditiveSharingTensor)
    assert len(self.child) == len(other.child)

    provider = self.crypto_provider
    if provider is None:
        raise AttributeError("For multiplication a crypto_provider must be passed.")

    return spdz.spdz_mul(torch_op, self, other, provider, self.field)