Example #1
0
    def _abstract_mul(self, equation: str, shares: dict, other_shares, **kwargs):
        """Multiply two additively shared tensors through spdz_mul.

        Args:
            equation: name of the torch operation to apply; must be either "mul"
                or "matmul" (it is resolved with getattr(torch, equation)).
            shares: a dictionary <location_id -> PointerTensor> of shares corresponding to
                self. Equivalent to calling self.child.
            other_shares: a dictionary <location_id -> PointerTensor> of shares corresponding
                to the tensor being multiplied by self. Any non-dict value is treated as a
                public constant: it is wrapped in a one-element torch.Tensor and shared
                across the locations listed in self.child before multiplying.
            **kwargs: accepted for call compatibility; not used by this method.

        Returns:
            The result of set_shares() called with the product shares on a new
            AdditiveSharingTensor built with this tensor's field and crypto provider.

        Raises:
            ValueError: if equation is not "mul" or "matmul", or if both operands do
                not hold the same number of shares.
            AttributeError: if self.crypto_provider is None.
        """
        # Explicit raises rather than `assert`: assertions are removed under
        # `python -O`, which would let an unsupported operation slip through.
        if equation not in ("mul", "matmul"):
            raise ValueError(
                'equation must be "mul" or "matmul", got {!r}'.format(equation)
            )
        cmd = getattr(torch, equation)

        # Checked before anything is shared so that a missing provider does not
        # leave behind freshly created shares of a public constant.
        if self.crypto_provider is None:
            raise AttributeError("For multiplication a crypto_provider must be passed.")

        # if someone passes in a constant... (e.g., x * 3)
        # TODO: Handle public mul more efficiently
        if not isinstance(other_shares, dict):
            other_shares = torch.Tensor([other_shares]).share(*self.child.keys()).child

        if len(shares) != len(other_shares):
            raise ValueError(
                "both operands must have the same number of shares, got {} and {}".format(
                    len(shares), len(other_shares)
                )
            )

        product_shares = spdz_mul(
            cmd, shares, other_shares, self.crypto_provider, self.field
        )

        return AdditiveSharingTensor(
            field=self.field, crypto_provider=self.crypto_provider
        ).set_shares(product_shares)
Example #2
0
    def _private_mul(self, other, equation: str):
        """Multiply this tensor by another AdditiveSharingTensor through spdz.spdz_mul.

        Args:
            self: an AdditiveSharingTensor
            other: another AdditiveSharingTensor
            equation: name of the torch operation to apply; must be either "mul"
                or "matmul" (it is resolved with getattr(torch, equation)).

        Returns:
            Whatever spdz.spdz_mul returns when called with the torch operation,
            both tensors, this tensor's crypto provider and its field.

        Raises:
            ValueError: if equation is not "mul" or "matmul", or if self.child and
                other.child do not have the same length.
            TypeError: if other is not an AdditiveSharingTensor.
            AttributeError: if self.crypto_provider is None.
        """
        # Explicit raises rather than `assert`: assertions are removed under
        # `python -O`, which would silently disable all of this validation.
        if equation not in ("mul", "matmul"):
            raise ValueError(
                'equation must be "mul" or "matmul", got {!r}'.format(equation)
            )
        cmd = getattr(torch, equation)

        if not isinstance(other, AdditiveSharingTensor):
            raise TypeError(
                "other must be an AdditiveSharingTensor, got {}".format(
                    type(other).__name__
                )
            )

        if len(self.child) != len(other.child):
            raise ValueError(
                "both operands must have the same number of shares, got {} and {}".format(
                    len(self.child), len(other.child)
                )
            )

        if self.crypto_provider is None:
            raise AttributeError(
                "For multiplication a crypto_provider must be passed.")

        # The tensors themselves (not their share dictionaries) are handed over.
        return spdz.spdz_mul(cmd, self, other, self.crypto_provider,
                             self.field)