Example 1
    def rand(self, *sizes, encoder=None):
        if encoder is None:
            encoder = FixedPointEncoder()  # use default precision

        r = encoder.encode(torch.rand(*sizes))
        r = r - self._get_additive_PRSS(sizes, remove_rank=True)
        return r
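For orientation, here is a minimal sketch (an assumption about the behaviour, not CrypTen's actual implementation) of the fixed-point encoding that `FixedPointEncoder` performs in these snippets: floats are scaled by `2 ** precision_bits` and rounded into 64-bit integers, and decoding divides the scale back out.

import torch

precision_bits = 16                     # assumed default precision
scale = 2 ** precision_bits

x = torch.rand(3)                       # plaintext floats in [0, 1)
encoded = (x * scale).round().long()    # fixed-point integer representation
decoded = encoded.float() / scale       # approximate floats again

assert torch.allclose(x, decoded, atol=1.0 / scale)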
Example 2
    def __init__(self, tensor=None, size=None, src=0, device=None):
        if src == SENTINEL:
            return
        assert (isinstance(src, int) and src >= 0
                and src < comm.get().get_world_size()), "invalid tensor source"

        if device is None and hasattr(tensor, "device"):
            device = tensor.device

        #  Assume 0 bits of precision unless encoder is set outside of init
        self.encoder = FixedPointEncoder(precision_bits=0)
        if tensor is not None:
            tensor = self.encoder.encode(tensor)
            tensor = tensor.to(device=device)
            size = tensor.size()

        # Generate Pseudo-random Sharing of Zero and add source's tensor
        self.share = BinarySharedTensor.PRZS(size, device=device).share
        if self.rank == src:
            assert tensor is not None, "Source must provide a data tensor"
            if hasattr(tensor, "src"):
                assert (
                    tensor.src == src
                ), "Source of data tensor must match source of encryption"
            self.share ^= tensor
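To make the PRZS step concrete, here is a plaintext sketch (all parties simulated locally, values chosen only for illustration) of why XOR-ing the pairwise shares cancels to zero, so that the XOR of all shares reveals exactly the source's tensor:

import torch

world_size, src = 3, 0
secret = torch.tensor([12345, -7], dtype=torch.long)

# pairwise seeds r_0 .. r_{n-1}; party i's zero-share is r_i ^ r_{(i+1) % n}
r = [torch.randint(-2**62, 2**62, secret.size(), dtype=torch.long)
     for _ in range(world_size)]
shares = [r[i] ^ r[(i + 1) % world_size] for i in range(world_size)]
shares[src] ^= secret                   # source folds its data into its share

revealed = shares[0] ^ shares[1] ^ shares[2]
assert torch.equal(revealed, secret)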
Example 3
    def __init__(self,
                 tensor=None,
                 size=None,
                 precision=None,
                 src=0,
                 device=None):
        if src == SENTINEL:
            return
        assert (isinstance(src, int) and src >= 0
                and src < comm.get().get_world_size()), "invalid tensor source"

        if device is None and hasattr(tensor, "device"):
            device = tensor.device

        self.encoder = FixedPointEncoder(precision_bits=precision)
        if tensor is not None:
            if is_int_tensor(tensor) and precision != 0:
                tensor = tensor.float()
            tensor = self.encoder.encode(tensor)
            tensor = tensor.to(device=device)
            size = tensor.size()

        # Generate pseudo-random sharing of zero (PRZS) and add source's tensor
        self.share = ArithmeticSharedTensor.PRZS(size, device=device).share
        if self.rank == src:
            assert tensor is not None, "Source must provide a data tensor"
            if hasattr(tensor, "src"):
                assert (
                    tensor.src == src
                ), "Source of data tensor must match source of encryption"
            self.share += tensor
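The arithmetic variant follows the same pattern. A plaintext sketch (again simulated locally) showing that the PRZS shares telescope to zero under 64-bit wraparound addition, so summing the shares reveals the source's tensor:

import torch

world_size, src = 3, 0
secret = torch.tensor([42, -1000], dtype=torch.long)

r = [torch.randint(-2**62, 2**62, secret.size(), dtype=torch.long)
     for _ in range(world_size)]
shares = [r[i] - r[(i + 1) % world_size] for i in range(world_size)]
shares[src] = shares[src] + secret      # source adds its encoded tensor

revealed = shares[0] + shares[1] + shares[2]   # wraps modulo 2 ** 64
assert torch.equal(revealed, secret)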
Example 4
    def __init__(
        self, tensor=None, size=None, broadcast_size=False, src=0, device=None
    ):
        """
        Creates the shared tensor from the input `tensor` provided by party `src`.

        The other parties can specify a `tensor` or `size` to determine the size
        of the shared tensor object to create. In this case, all parties must
        specify the same (tensor) size to prevent the parties' shares from varying
        in size, which leads to undefined behavior.

        Alternatively, the parties can set `broadcast_size` to `True` to have the
        `src` party broadcast the correct size. The parties who do not know the
        tensor size beforehand can provide an empty tensor as input. This is
        guaranteed to produce correct behavior but requires an additional
        communication round.

        The parties can also set the `precision` and `device` for their share of
        the tensor. If `device` is unspecified, it is set to `tensor.device`.
        """

        # do nothing if source is sentinel:
        if src == SENTINEL:
            return

        # assertions on inputs:
        assert (
            isinstance(src, int) and src >= 0 and src < comm.get().get_world_size()
        ), "specified source party does not exist"
        if self.rank == src:
            assert tensor is not None, "source must provide a data tensor"
            if hasattr(tensor, "src"):
                assert (
                    tensor.src == src
                ), "source of data tensor must match source of encryption"
        if not broadcast_size:
            assert (
                tensor is not None or size is not None
            ), "must specify tensor or size, or set broadcast_size"

        # if device is unspecified, try and get it from tensor:
        if device is None and tensor is not None and hasattr(tensor, "device"):
            device = tensor.device

        # assume zero bits of precision unless encoder is set outside of init:
        self.encoder = FixedPointEncoder(precision_bits=0)
        if tensor is not None:
            tensor = self.encoder.encode(tensor)
            tensor = tensor.to(device=device)
            size = tensor.size()

        # if other parties do not know tensor's size, broadcast the size:
        if broadcast_size:
            size = comm.get().broadcast_obj(size, src)

        # generate pseudo-random zero sharing (PRZS) and add source's tensor:
        self.share = BinarySharedTensor.PRZS(size, device=device).share
        if self.rank == src:
            self.share ^= tensor
Example 5
    def from_shares(share, precision=None, src=0, device=None):
        """Generate a BinarySharedTensor from a share from each party"""
        result = BinarySharedTensor(src=SENTINEL)
        share = share.to(device) if device is not None else share
        result.share = CUDALongTensor(share) if share.is_cuda else share
        result.encoder = FixedPointEncoder(precision_bits=precision)
        return result
Example 6
def _B2A(binary_tensor, precision=None, bits=None):
    if bits is None:
        bits = torch.iinfo(torch.long).bits

    if bits == 1:
        binary_bit = binary_tensor & 1
        arithmetic_tensor = beaver.B2A_single_bit(binary_bit)
    else:
        binary_bits = BinarySharedTensor.stack(
            [binary_tensor >> i for i in range(bits)])
        binary_bits = binary_bits & 1
        arithmetic_bits = beaver.B2A_single_bit(binary_bits)

        multiplier = torch.cat([
            torch.tensor([1], dtype=torch.long, device=binary_tensor.device) <<
            i for i in range(bits)
        ])
        while multiplier.dim() < arithmetic_bits.dim():
            multiplier = multiplier.unsqueeze(1)

        arithmetic_tensor = arithmetic_bits.mul_(multiplier).sum(0)

    arithmetic_tensor.encoder = FixedPointEncoder(precision_bits=precision)
    scale = arithmetic_tensor.encoder._scale // binary_tensor.encoder._scale
    arithmetic_tensor *= scale
    return arithmetic_tensor
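A standalone sanity check (plain torch, no secret sharing) of the recombination above: stacking the 64 bit planes and weighting bit i by `1 << i` reconstructs the original signed value, because shifting 1 left by 63 wraps to -2 ** 63, the two's-complement sign weight.

import torch

bits = torch.iinfo(torch.long).bits                   # 64
x = torch.tensor([123456789, -42], dtype=torch.long)

bit_planes = torch.stack([(x >> i) & 1 for i in range(bits)])   # shape (64, 2)
multiplier = torch.cat([
    torch.tensor([1], dtype=torch.long) << i for i in range(bits)
])
while multiplier.dim() < bit_planes.dim():
    multiplier = multiplier.unsqueeze(1)

reconstructed = bit_planes.mul(multiplier).sum(0)
assert torch.equal(reconstructed, x)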
Example 7
    def test_from_shares(self):
        """Tests crypten.nn.Module.set_parameter_from_shares() functionality."""

        # create simple model:
        input_size, output_size = 3, 10
        model = crypten.nn.Linear(input_size, output_size)

        # helper function that creates arithmetically shared tensor of some size:
        def _generate_parameters(size):
            num_parties = int(self.world_size)
            reference = get_random_test_tensor(size=size, is_float=False)
            zero_shares = generate_random_ring_element((num_parties, *size))
            zero_shares = zero_shares - zero_shares.roll(1, dims=0)
            shares = list(zero_shares.unbind(0))
            shares[0] += reference
            return shares, reference

        # generate new set of parameters:
        all_shares, all_references = {}, {}
        for name, param in model.named_parameters():
            shares, reference = _generate_parameters(param.size())
            share = comm.get().scatter(shares, 0)
            all_shares[name] = share
            all_references[name] = reference

        # cannot load parameters from share when model is not encrypted:
        with self.assertRaises(AssertionError):
            for name, share in all_shares.items():
                model.set_parameter_from_shares(name, share)

        # cannot load shares into non-existent parameters:
        model.encrypt()
        with self.assertRaises(ValueError):
            model.set_parameter_from_shares("__DUMMY__", None)

        # load parameter shares into model and check results:
        for name, share in all_shares.items():
            model.set_parameter_from_shares(name, share)
        model.decrypt()
        encoder = FixedPointEncoder()
        for name, param in model.named_parameters():
            reference = encoder.decode(all_references[name])
            self.assertTrue(torch.allclose(param, reference))
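The zero-share trick in `_generate_parameters` works because subtracting a cyclic roll of a random tensor from itself telescopes to zero along the party dimension. A small plaintext check, using `torch.randint` in place of `generate_random_ring_element`:

import torch

num_parties, size = 3, (2, 4)
reference = torch.randint(-1000, 1000, size, dtype=torch.long)

zero_shares = torch.randint(-2**62, 2**62, (num_parties, *size), dtype=torch.long)
zero_shares = zero_shares - zero_shares.roll(1, dims=0)   # rows now sum to zero
shares = list(zero_shares.unbind(0))
shares[0] = shares[0] + reference

assert torch.equal(shares[0] + shares[1] + shares[2], reference)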
Example 8
def _B2A(binary_tensor, precision=None, bits=None):
    if bits is None:
        bits = torch.iinfo(torch.long).bits

    arithmetic_tensor = 0
    for i in range(bits):
        binary_bit = binary_tensor & 1
        arithmetic_bit = beaver.B2A_single_bit(binary_bit)
        # avoids long integer overflow since 2 ** 63 is out of range
        # (aliases to -2 ** 63)
        if i == 63:
            arithmetic_tensor += arithmetic_bit * (-2**63)
        else:
            arithmetic_tensor += arithmetic_bit * (2**i)
        binary_tensor >>= 1
    arithmetic_tensor.encoder = FixedPointEncoder(precision_bits=precision)
    scale = arithmetic_tensor.encoder._scale // binary_tensor.encoder._scale
    arithmetic_tensor *= scale
    return arithmetic_tensor
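A pure-Python check of the `i == 63` special case above: 2 ** 63 and -2 ** 63 are the same residue modulo 2 ** 64, so weighting the top bit by -2 ** 63 keeps the constant inside the signed 64-bit range while still reconstructing the correct two's-complement value.

int64_min = -2 ** 63
assert (2 ** 63) % (2 ** 64) == int64_min % (2 ** 64)

# reconstructing -1 (all 64 bits set) with the adjusted weight for bit 63:
value = sum(2 ** i for i in range(63)) + int64_min
assert value == -1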
Example 9
    def from_shares(share, precision=None, src=0):
        """Generate a BinarySharedTensor from a share from each party"""
        result = BinarySharedTensor(src=SENTINEL)
        result.share = share
        result.encoder = FixedPointEncoder(precision_bits=precision)
        return result
Example 10
class BinarySharedTensor(object):
    """
        Encrypted tensor object that uses binary sharing to perform computations.

        Binary shares are computed by splitting each value of the input tensor
        into n separate random values that xor together to the input tensor value,
        where n is the number of parties present in the protocol (world_size).
    """

    def __init__(self, tensor=None, size=None, src=0):
        if src == SENTINEL:
            return
        assert (
            isinstance(src, int) and src >= 0 and src < comm.get().get_world_size()
        ), "invalid tensor source"

        #  Assume 0 bits of precision unless encoder is set outside of init
        self.encoder = FixedPointEncoder(precision_bits=0)
        if tensor is not None:
            tensor = self.encoder.encode(tensor)
            size = tensor.size()

        # Generate Pseudo-random Sharing of Zero and add source's tensor
        self.share = BinarySharedTensor.PRZS(size).share
        if self.rank == src:
            assert tensor is not None, "Source must provide a data tensor"
            if hasattr(tensor, "src"):
                assert (
                    tensor.src == src
                ), "Source of data tensor must match source of encryption"
            self.share ^= tensor

    @staticmethod
    def from_shares(share, precision=None, src=0):
        """Generate a BinarySharedTensor from a share from each party"""
        result = BinarySharedTensor(src=SENTINEL)
        result.share = share
        result.encoder = FixedPointEncoder(precision_bits=precision)
        return result

    @staticmethod
    def PRZS(*size):
        """
        Generate a Pseudo-random Sharing of Zero (using binary shares)

        This function does so by generating `n` numbers across `n` parties with
        each number being held by exactly 2 parties. Therefore, each party holds
        two numbers. A zero sharing is found by having each party xor their two
        numbers together.
        """
        tensor = BinarySharedTensor(src=SENTINEL)
        current_share = generate_kbit_random_tensor(*size, generator=comm.get().g0)
        next_share = generate_kbit_random_tensor(*size, generator=comm.get().g1)
        tensor.share = current_share ^ next_share
        return tensor

    @property
    def rank(self):
        return comm.get().get_rank()

    @property
    def share(self):
        """Returns underlying _tensor"""
        return self._tensor

    @share.setter
    def share(self, value):
        """Sets _tensor to value"""
        self._tensor = value

    def shallow_copy(self):
        """Create a shallow copy"""
        result = BinarySharedTensor(src=SENTINEL)
        result.encoder = self.encoder
        result.share = self.share
        return result

    def copy_(self, other):
        """Copies other tensor into this tensor."""
        self.share.copy_(other.share)
        self.encoder = other.encoder

    def __repr__(self):
        return f"BinarySharedTensor({self.share})"

    def __bool__(self):
        """Override bool operator since encrypted tensors cannot evaluate"""
        raise RuntimeError("Cannot evaluate BinarySharedTensors to boolean values")

    def __nonzero__(self):
        """__bool__ for backwards compatibility with Python 2"""
        raise RuntimeError("Cannot evaluate BinarySharedTensors to boolean values")

    def __ixor__(self, y):
        """Bitwise XOR operator (element-wise) in place"""
        if torch.is_tensor(y) or isinstance(y, int):
            if self.rank == 0:
                self.share ^= y
        elif isinstance(y, BinarySharedTensor):
            self.share ^= y.share
        else:
            raise TypeError("Cannot XOR %s with %s." % (type(y), type(self)))
        return self

    def __xor__(self, y):
        """Bitwise XOR operator (element-wise)"""
        result = self.clone()
        if isinstance(y, BinarySharedTensor):
            broadcast_tensors = torch.broadcast_tensors(result.share, y.share)
            result.share = broadcast_tensors[0].clone()
        elif torch.is_tensor(y):
            broadcast_tensors = torch.broadcast_tensors(result.share, y)
            result.share = broadcast_tensors[0].clone()
        return result.__ixor__(y)

    def __iand__(self, y):
        """Bitwise AND operator (element-wise) in place"""
        if torch.is_tensor(y) or isinstance(y, int):
            self.share &= y
        elif isinstance(y, BinarySharedTensor):
            self.share.data = beaver.AND(self, y).share.data
        else:
            raise TypeError("Cannot AND %s with %s." % (type(y), type(self)))
        return self

    def __and__(self, y):
        """Bitwise AND operator (element-wise)"""
        result = self.clone()
        # TODO: Remove explicit broadcasts to allow smaller beaver triples
        if isinstance(y, BinarySharedTensor):
            broadcast_tensors = torch.broadcast_tensors(result.share, y.share)
            result.share = broadcast_tensors[0].clone()
        elif torch.is_tensor(y):
            broadcast_tensors = torch.broadcast_tensors(result.share, y)
            result.share = broadcast_tensors[0].clone()
        return result.__iand__(y)

    def __ior__(self, y):
        """Bitwise OR operator (element-wise) in place"""
        xor_result = self ^ y
        return self.__iand__(y).__ixor__(xor_result)

    def __or__(self, y):
        """Bitwise OR operator (element-wise)"""
        return self.__and__(y) ^ self ^ y

    def __invert__(self):
        """Bitwise NOT operator (element-wise)"""
        result = self.clone()
        if result.rank == 0:
            result.share ^= -1
        return result

    def lshift_(self, value):
        """Left shift elements by `value` bits"""
        assert isinstance(value, int), "lshift must take an integer argument."
        self.share <<= value
        return self

    def lshift(self, value):
        """Left shift elements by `value` bits"""
        return self.clone().lshift_(value)

    def rshift_(self, value):
        """Right shift elements by `value` bits"""
        assert isinstance(value, int), "rshift must take an integer argument."
        self.share >>= value
        return self

    def rshift(self, value):
        """Right shift elements by `value` bits"""
        return self.clone().rshift_(value)

    # Circuits
    def add(self, y):
        """Compute [self] + [y] for xor-sharing"""
        return circuit.add(self, y)

    def __setitem__(self, index, value):
        """Set tensor values by index"""
        if torch.is_tensor(value) or isinstance(value, list):
            value = BinarySharedTensor(value)
        assert isinstance(
            value, BinarySharedTensor
        ), "Unsupported input type %s for __setitem__" % type(value)
        self.share.__setitem__(index, value.share)

    @staticmethod
    def stack(seq, *args, **kwargs):
        """Stacks a list of tensors along a given dimension"""
        assert isinstance(seq, list), "Stack input must be a list"
        assert isinstance(
            seq[0], BinarySharedTensor
        ), "Sequence must contain BinarySharedTensors"
        result = seq[0].shallow_copy()
        result.share = torch.stack(
            [tensor.share for tensor in seq], *args, **kwargs
        )
        return result

    def sum(self, dim=None):
        """Add all tensors along a given dimension using a log-reduction"""
        if dim is None:
            x = self.flatten()
        else:
            x = self.transpose(0, dim)

        # Add all BinarySharedTensors
        while x.size(0) > 1:
            extra = None
            if x.size(0) % 2 == 1:
                extra = x[0]
                x = x[1:]
            x0 = x[: (x.size(0) // 2)]
            x1 = x[(x.size(0) // 2) :]
            x = x0 + x1
            if extra is not None:
                x.share = torch.cat([x.share, extra.share.unsqueeze(0)])

        if dim is None:
            x = x.squeeze()
        else:
            x = x.transpose(0, dim).squeeze(dim)
        return x

    def cumsum(self, *args, **kwargs):
        raise NotImplementedError("BinarySharedTensor cumsum not implemented")

    def trace(self, *args, **kwargs):
        raise NotImplementedError("BinarySharedTensor trace not implemented")

    @staticmethod
    def reveal_batch(tensor_or_list, dst=None):
        """Get (batched) plaintext without any downscaling"""
        if isinstance(tensor_or_list, BinarySharedTensor):
            return tensor_or_list.reveal(dst=dst)

        assert isinstance(
            tensor_or_list, list
        ), f"Invalid input type into reveal {type(tensor_or_list)}"
        shares = [tensor.share for tensor in tensor_or_list]
        op = torch.distributed.ReduceOp.BXOR
        if dst is None:
            return comm.get().all_reduce(shares, op=op, batched=True)
        else:
            return comm.get().reduce(shares, dst=dst, op=op, batched=True)

    def reveal(self, dst=None):
        """Get plaintext without any downscaling"""
        op = torch.distributed.ReduceOp.BXOR
        if dst is None:
            return comm.get().all_reduce(self.share, op=op)
        else:
            return comm.get().reduce(self.share, dst=dst, op=op)

    def get_plain_text(self, dst=None):
        """Decrypts the tensor."""
        # Edge case where share becomes 0 sized (e.g. result of split)
        if self.nelement() < 1:
            return torch.empty(self.share.size())
        return self.encoder.decode(self.reveal(dst=dst))

    def where(self, condition, y):
        """Selects elements from self or y based on condition

        Args:
            condition (torch.bool or BinarySharedTensor): when True yield self,
                otherwise yield y. Note condition is not bitwise.
            y (torch.tensor or BinarySharedTensor): selected when condition is
                False.

        Returns: BinarySharedTensor or torch.tensor.
        """
        if torch.is_tensor(condition):
            condition = condition.long()
            is_binary = ((condition == 1) | (condition == 0)).all()
            assert is_binary, "condition values must be 0 or 1"
            # -1 mult expands 0 into binary 00...00 and 1 into 11...11
            condition_expanded = -condition
            y_masked = y & (~condition_expanded)
        elif isinstance(condition, BinarySharedTensor):
            condition_expanded = condition.clone()
            # -1 mult expands binary while & 1 isolates first bit
            condition_expanded.share = -(condition_expanded.share & 1)
            # encrypted tensor must be first operand
            y_masked = (~condition_expanded) & y
        else:
            msg = f"condition {condition} must be torch.bool, or BinarySharedTensor"
            raise ValueError(msg)

        return (self & condition_expanded) ^ y_masked

    def scatter_(self, dim, index, src):
        """Writes all values from the tensor `src` into `self` at the indices
        specified in the `index` tensor. For each value in `src`, its output index
        is specified by its index in `src` for `dimension != dim` and by the
        corresponding value in `index` for `dimension = dim`.
        """
        if torch.is_tensor(src):
            src = BinarySharedTensor(src)
        assert isinstance(
            src, BinarySharedTensor
        ), "Unrecognized scatter src type: %s" % type(src)
        self.share.scatter_(dim, index, src.share)
        return self

    def scatter(self, dim, index, src):
        """Writes all values from the tensor `src` into `self` at the indices
        specified in the `index` tensor. For each value in `src`, its output index
        is specified by its index in `src` for `dimension != dim` and by the
        corresponding value in `index` for `dimension = dim`.
        """
        result = self.clone()
        return result.scatter_(dim, index, src)

    # Bitwise operators
    __add__ = add
    __lshift__ = lshift
    __rshift__ = rshift

    # In-place bitwise operators
    __ilshift__ = lshift_
    __irshift__ = rshift_

    # Reversed boolean operations
    __radd__ = __add__
    __rxor__ = __xor__
    __rand__ = __and__
    __ror__ = __or__
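Two identities the class above leans on, checked here on plain long tensors (no secret sharing involved): negating a 0/1 mask yields an all-zero or all-one bit pattern under two's complement (used by `where`), and a | b == (a & b) ^ a ^ b (used by `__or__`).

import torch

condition = torch.tensor([0, 1, 1, 0], dtype=torch.long)
mask = -condition                       # 0 -> 00...0, 1 -> 11...1
assert torch.equal(mask, torch.tensor([0, -1, -1, 0]))

a = torch.randint(-2**62, 2**62, (5,), dtype=torch.long)
b = torch.randint(-2**62, 2**62, (5,), dtype=torch.long)
assert torch.equal(a | b, (a & b) ^ a ^ b)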
Example 11
    def test_encode_decode(self):
        """Tests tensor encoding and decoding."""
        for float in [False, True]:
            if float:
                fpe = FixedPointEncoder(precision_bits=16)
            else:
                fpe = FixedPointEncoder(precision_bits=0)
            tensor = get_test_tensor(float=float)
            decoded = fpe.decode(fpe.encode(tensor))
            self._check(
                decoded,
                tensor,
                "Encoding/decoding a %s failed." %
                "float" if float else "long",
            )

        # Make sure encoding a subclass of CrypTensor is a no-op
        crypten.mpc.set_default_provider(
            crypten.mpc.provider.TrustedFirstParty)
        crypten.init()

        tensor = get_test_tensor(float=True)
        encrypted_tensor = crypten.cryptensor(tensor)
        encrypted_tensor = fpe.encode(encrypted_tensor)
        self._check(
            encrypted_tensor.get_plain_text(),
            tensor,
            "Encoding an EncryptedTensor failed.",
        )

        # Try a few other types.
        fpe = FixedPointEncoder(precision_bits=0)
        for dtype in [torch.uint8, torch.int8, torch.int16]:
            tensor = torch.zeros(5, dtype=dtype).random_()
            decoded = fpe.decode(fpe.encode(tensor)).type(dtype)
            self._check(decoded, tensor,
                        "Encoding/decoding a %s failed." % dtype)
Example 12
class ArithmeticSharedTensor(CrypTensor):
    """
        Encrypted tensor object that uses additive sharing to perform computations.

        Additive shares are computed by splitting each value of the input tensor
        into n separate random values that add to the input tensor, where n is
        the number of parties present in the protocol (world_size).
    """

    # constructors:
    def __init__(self, tensor=None, size=None, precision=None, src=0):
        if src == SENTINEL:
            return
        assert (
            isinstance(src, int) and src >= 0 and src < comm.get().get_world_size()
        ), "invalid tensor source"

        self.encoder = FixedPointEncoder(precision_bits=precision)
        if tensor is not None:
            if is_int_tensor(tensor) and precision != 0:
                tensor = tensor.float()
            tensor = self.encoder.encode(tensor)
            size = tensor.size()

        # Generate pseudo-random sharing of zero (PRZS) and add source's tensor
        self.share = ArithmeticSharedTensor.PRZS(size).share
        if self.rank == src:
            assert tensor is not None, "Source must provide a data tensor"
            if hasattr(tensor, "src"):
                assert (
                    tensor.src == src
                ), "Source of data tensor must match source of encryption"
            self.share += tensor

    @property
    def share(self):
        """Returns underlying _tensor"""
        return self._tensor

    @share.setter
    def share(self, value):
        """Sets _tensor to value"""
        self._tensor = value

    @staticmethod
    def from_shares(share, precision=None, src=0):
        """Generate an ArithmeticSharedTensor from a share from each party"""
        result = ArithmeticSharedTensor(src=SENTINEL)
        result.share = share
        result.encoder = FixedPointEncoder(precision_bits=precision)
        return result

    @staticmethod
    def PRZS(*size):
        """
        Generate a Pseudo-random Sharing of Zero (using arithmetic shares)

        This function does so by generating `n` numbers across `n` parties with
        each number being held by exactly 2 parties. One of these parties adds
        this number while the other subtracts this number.
        """
        tensor = ArithmeticSharedTensor(src=SENTINEL)
        current_share = generate_random_ring_element(*size, generator=comm.get().g0)
        next_share = generate_random_ring_element(*size, generator=comm.get().g1)
        tensor.share = current_share - next_share
        return tensor

    @property
    def rank(self):
        return comm.get().get_rank()

    def shallow_copy(self):
        """Create a shallow copy"""
        result = ArithmeticSharedTensor(src=SENTINEL)
        result.encoder = self.encoder
        result.share = self.share
        return result

    def __repr__(self):
        return f"ArithmeticSharedTensor({self.share})"

    def __bool__(self):
        """Override bool operator since encrypted tensors cannot evaluate"""
        raise RuntimeError("Cannot evaluate ArithmeticSharedTensors to boolean values")

    def __nonzero__(self):
        """__bool__ for backwards compatibility with Python 2"""
        raise RuntimeError("Cannot evaluate ArithmeticSharedTensors to boolean values")

    def __setitem__(self, index, value):
        """Set tensor values by index"""
        if isinstance(value, (int, float)) or torch.is_tensor(value):
            value = ArithmeticSharedTensor(value)
        assert isinstance(
            value, ArithmeticSharedTensor
        ), "Unsupported input type %s for __setitem__" % type(value)
        self.share.__setitem__(index, value.share)

    def pad(self, pad, mode="constant", value=0):
        """
            Pads the input tensor with values provided in `value`.
        """
        assert mode == "constant", (
            "Padding with mode %s is currently unsupported" % mode
        )

        result = self.shallow_copy()
        if isinstance(value, (int, float)):
            value = self.encoder.encode(value).item()
            if result.rank == 0:
                result.share = torch.nn.functional.pad(
                    result.share, pad, mode=mode, value=value
                )
            else:
                result.share = torch.nn.functional.pad(
                    result.share, pad, mode=mode, value=0
                )
        elif isinstance(value, ArithmeticSharedTensor):
            assert (
                value.dim() == 0
            ), "Private values used for padding must be 0-dimensional"
            value = value.share.item()
            result.share = torch.nn.functional.pad(
                result.share, pad, mode=mode, value=value
            )
        else:
            raise TypeError(
                "Cannot pad ArithmeticSharedTensor with a %s value" % type(value)
            )

        return result

    @staticmethod
    def stack(tensors, *args, **kwargs):
        """Perform tensor stacking"""
        for i, tensor in enumerate(tensors):
            if torch.is_tensor(tensor):
                tensors[i] = ArithmeticSharedTensor(tensor)
            assert isinstance(
                tensors[i], ArithmeticSharedTensor
            ), "Can't stack %s with ArithmeticSharedTensor" % type(tensor)

        result = tensors[0].shallow_copy()
        result.share = torch.stack(
            [tensor.share for tensor in tensors], *args, **kwargs
        )
        return result

    def reveal(self, dst=None):
        """Get plaintext without any downscaling"""
        tensor = self.share.clone()
        if dst is None:
            return comm.get().all_reduce(tensor)
        else:
            return comm.get().reduce(tensor, dst=dst)

    def get_plain_text(self, dst=None):
        """Decrypt the tensor"""
        # Edge case where share becomes 0 sized (e.g. result of split)
        if self.nelement() < 1:
            return torch.empty(self.share.size())
        return self.encoder.decode(self.reveal(dst=dst))

    def _arithmetic_function_(self, y, op, *args, **kwargs):
        return self._arithmetic_function(y, op, inplace=True, *args, **kwargs)

    def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs):
        assert op in [
            "add",
            "sub",
            "mul",
            "matmul",
            "conv2d",
            "conv_transpose2d",
        ], f"Provided op `{op}` is not a supported arithmetic function"

        additive_func = op in ["add", "sub"]
        public = isinstance(y, (int, float)) or torch.is_tensor(y)
        private = isinstance(y, ArithmeticSharedTensor)

        if inplace:
            result = self
            if additive_func or (op == "mul" and public):
                op += "_"
        else:
            result = self.clone()

        if public:
            y = result.encoder.encode(y)

            if additive_func:  # ['add', 'sub']
                if result.rank == 0:
                    result.share = getattr(result.share, op)(y)
                else:
                    result.share = torch.broadcast_tensors(result.share, y)[0]
            elif op == "mul_":  # ['mul_']
                result.share = result.share.mul_(y)
            else:  # ['mul', 'matmul', 'conv2d', 'conv_transpose2d']
                result.share = getattr(torch, op)(result.share, y, *args, **kwargs)
        elif private:
            if additive_func:  # ['add', 'sub', 'add_', 'sub_']
                result.share = getattr(result.share, op)(y.share)
            else:  # ['mul', 'matmul', 'conv2d', 'conv_transpose2d']
                # NOTE: 'mul_' calls 'mul' here
                # Must copy _tensor.data here to support 'mul_' being inplace
                result.share.data = getattr(beaver, op)(
                    result, y, *args, **kwargs
                ).share.data
        else:
            raise TypeError("Cannot %s %s with %s" % (op, type(y), type(self)))

        # Scale by encoder scale if necessary
        if not additive_func:
            if public:  # scale by self.encoder.scale
                if self.encoder.scale > 1:
                    return result.div_(result.encoder.scale)
                else:
                    result.encoder = self.encoder
            else:  # scale by larger of self.encoder.scale and y.encoder.scale
                if self.encoder.scale > 1 and y.encoder.scale > 1:
                    return result.div_(result.encoder.scale)
                elif self.encoder.scale > 1:
                    result.encoder = self.encoder
                else:
                    result.encoder = y.encoder

        return result

    def add(self, y):
        """Perform element-wise addition"""
        return self._arithmetic_function(y, "add")

    def add_(self, y):
        """Perform element-wise addition"""
        return self._arithmetic_function_(y, "add")

    def sub(self, y):
        """Perform element-wise subtraction"""
        return self._arithmetic_function(y, "sub")

    def sub_(self, y):
        """Perform element-wise subtraction"""
        return self._arithmetic_function_(y, "sub")

    def mul(self, y):
        """Perform element-wise multiplication"""
        if isinstance(y, int) or is_int_tensor(y):
            result = self.clone()
            result.share = self.share * y
            return result
        return self._arithmetic_function(y, "mul")

    def mul_(self, y):
        """Perform element-wise multiplication"""
        if isinstance(y, int) or is_int_tensor(y):
            self.share *= y
            return self
        return self._arithmetic_function_(y, "mul")

    def div(self, y):
        """Divide by a given tensor"""
        result = self.clone()
        if isinstance(y, CrypTensor):
            result.share = torch.broadcast_tensors(result.share, y.share)[0].clone()
        elif torch.is_tensor(y):
            result.share = torch.broadcast_tensors(result.share, y)[0].clone()
        return result.div_(y)

    def div_(self, y):
        """Divide two tensors element-wise"""
        # TODO: Add test coverage for this code path (next 4 lines)
        if isinstance(y, float) and int(y) == y:
            y = int(y)
        if is_float_tensor(y) and y.frac().eq(0).all():
            y = y.long()

        if isinstance(y, int) or is_int_tensor(y):
            # Truncate protocol for dividing by public integers:
            if comm.get().get_world_size() > 2:
                wraps = self.wraps()
                self.share /= y
                # NOTE: The multiplication here must be split into two parts
                # to avoid long out-of-bounds when y <= 2 since (2 ** 63) is
                # larger than the largest long integer.
                self -= wraps * 4 * (int(2 ** 62) // y)
            else:
                self.share /= y
            return self

        # Otherwise multiply by reciprocal
        if isinstance(y, float):
            y = torch.FloatTensor([y])

        assert is_float_tensor(y), "Unsupported type for div_: %s" % type(y)
        return self.mul_(y.reciprocal())

    def wraps(self):
        """Privately computes the number of wraparounds for a set a shares"""
        return beaver.wraps(self)

    def matmul(self, y):
        """Perform matrix multiplication using some tensor"""
        return self._arithmetic_function(y, "matmul")

    def prod(self, dim=None, keepdim=False):
        """
        Returns the product of each row of the `input` tensor in the given
        dimension `dim`.

        If `keepdim` is `True`, the output tensor is of the same size as `input`
        except in the dimension `dim` where it is of size 1. Otherwise, `dim` is
        squeezed, resulting in the output tensor having 1 fewer dimension than
        `input`.
        """
        if dim is None:
            return self.flatten().prod(dim=0)

        result = self.clone()
        while result.size(dim) > 1:
            size = result.size(dim)
            x, y, remainder = result.split([size // 2, size // 2, size % 2], dim=dim)
            result = x.mul_(y)
            result.share = torch.cat([result.share, remainder.share], dim=dim)

        # Squeeze result if necessary
        if not keepdim:
            result.share = result.share.squeeze(dim)
        return result

    def mean(self, *args, **kwargs):
        """Computes mean of given tensor"""
        result = self.sum(*args, **kwargs)

        # Handle special case where input has 0 dimensions
        if self.dim() == 0:
            return result

        # Compute divisor to use to compute mean
        size = self.size()
        if len(args) > 0:  # dimension is specified
            dims = [args[0]] if isinstance(args[0], int) else args[0]
            size = [size[dim] for dim in dims]
        assert len(size) > 0, "cannot reduce over zero dimensions"
        divisor = reduce(lambda x, y: x * y, size)

        return result.div(divisor)

    def var(self, *args, **kwargs):
        """Computes variance of tensor along specified dimensions."""
        if len(args) > 0:  # dimension is specified
            mean = self.mean(*args, **{"keepdim": True})
        else:
            mean = self.mean()
        result = (self - mean).square().sum(*args, **kwargs)
        size = self.size()
        if len(args) > 0:  # dimension is specified
            dims = [args[0]] if isinstance(args[0], int) else args[0]
            size = [size[dim] for dim in dims]
        assert len(size) > 0, "cannot reduce over zero dimensions"
        divisor = reduce(lambda x, y: x * y, size)
        return result.div(divisor)

    def conv2d(self, kernel, **kwargs):
        """Perform a 2D convolution using the given kernel"""
        return self._arithmetic_function(kernel, "conv2d", **kwargs)

    def conv_transpose2d(self, kernel, **kwargs):
        """Perform a 2D transpose convolution (deconvolution) using the given kernel"""
        return self._arithmetic_function(kernel, "conv_transpose2d", **kwargs)

    def index_add(self, dim, index, tensor):
        """Perform out-of-place index_add: Accumulate the elements of tensor into the
        self tensor by adding to the indices in the order given in index. """
        result = self.clone()
        return result.index_add_(dim, index, tensor)

    def index_add_(self, dim, index, tensor):
        """Perform in-place index_add: Accumulate the elements of tensor into the
        self tensor by adding to the indices in the order given in index. """
        public = isinstance(tensor, (int, float)) or torch.is_tensor(tensor)
        private = isinstance(tensor, ArithmeticSharedTensor)
        if public:
            enc_tensor = self.encoder.encode(tensor)
            if self.rank == 0:
                self._tensor.index_add_(dim, index, enc_tensor)
        elif private:
            self._tensor.index_add_(dim, index, tensor._tensor)
        else:
            raise TypeError("index_add second tensor of unsupported type")
        return self

    def scatter_add(self, dim, index, other):
        """Adds all values from the tensor other into self at the indices
        specified in the index tensor in a similar fashion as scatter_(). For
        each value in other, it is added to an index in self which is specified
        by its index in other for dimension != dim and by the corresponding
        value in index for dimension = dim.
        """
        return self.clone().scatter_add_(dim, index, other)

    def scatter_add_(self, dim, index, other):
        """Adds all values from the tensor other into self at the indices
        specified in the index tensor in a similar fashion as scatter_(). For
        each value in other, it is added to an index in self which is specified
        by its index in other for dimension != dim and by the corresponding
        value in index for dimension = dim.
        """
        public = isinstance(other, (int, float)) or torch.is_tensor(other)
        private = isinstance(other, CrypTensor)
        if public:
            if self.rank == 0:
                self.share.scatter_add_(dim, index, self.encoder.encode(other))
        elif private:
            self.share.scatter_add_(dim, index, other.share)
        else:
            raise TypeError("scatter_add second tensor of unsupported type")
        return self

    def avg_pool2d(self, kernel_size, *args, **kwargs):
        """Perform an average pooling on each 2D matrix of the given tensor

        Args:
            kernel_size (int or tuple): pooling kernel size.
        """
        z = self.sum_pool2d(kernel_size, *args, **kwargs)
        if isinstance(kernel_size, (int, float)):
            pool_size = kernel_size ** 2
        else:
            pool_size = kernel_size[0] * kernel_size[1]
        return z / pool_size

    def sum_pool2d(self, *args, **kwargs):
        """Perform a sum pooling on each 2D matrix of the given tensor"""
        result = self.shallow_copy()
        result.share = torch.nn.functional.avg_pool2d(
            self.share, *args, **kwargs, divisor_override=1
        )
        return result

    def take(self, index, dimension=None):
        """Take entries of tensor along a dimension according to the index.
            This function is identical to torch.take() when dimension=None,
            otherwise, it is identical to the ONNX gather() function.
        """
        result = self.shallow_copy()
        index = index.long()
        if dimension is None:
            result.share = torch.take(self.share, index)
        else:
            all_indices = [slice(0, x) for x in self.size()]
            all_indices[dimension] = index
            result.share = self.share[all_indices]
        return result

    # negation and reciprocal:
    def neg_(self):
        """Negate the tensor's values"""
        self.share.neg_()
        return self

    def neg(self):
        """Negate the tensor's values"""
        return self.clone().neg_()

    def square(self):
        result = self.clone()
        result.share = beaver.square(self).div_(self.encoder.scale).share
        return result

    # copy between CPU and GPU:
    def cuda(self):
        raise NotImplementedError("CUDA is not supported for ArithmeticSharedTensors")

    def cpu(self):
        raise NotImplementedError("CUDA is not supported for ArithmeticSharedTensors")

    def dot(self, y, weights=None):
        """Compute a dot product between two tensors"""
        assert self.size() == y.size(), "Number of elements does not match"
        if weights is not None:
            assert weights.size() == self.size(), "Incorrect number of weights"
            result = self * weights
        else:
            result = self.clone()

        return result.mul_(y).sum()

    def ger(self, y):
        """Computer an outer product between two vectors"""
        assert self.dim() == 1 and y.dim() == 1, "Outer product must be on 1D tensors"
        return self.view((-1, 1)).matmul(y.view((1, -1)))

    def where(self, condition, y):
        """Selects elements from self or y based on condition

        Args:
            condition (torch.bool or ArithmeticSharedTensor): when True
                yield self, otherwise yield y.
            y (torch.tensor or ArithmeticSharedTensor): values selected at
                indices where condition is False.

        Returns: ArithmeticSharedTensor or torch.tensor
        """
        if torch.is_tensor(condition):
            condition = condition.float()
            y_masked = y * (1 - condition)
        else:
            # encrypted tensor must be first operand
            y_masked = (1 - condition) * y

        return self * condition + y_masked

    def scatter_(self, dim, index, src):
        """Writes all values from the tensor `src` into `self` at the indices
        specified in the `index` tensor. For each value in `src`, its output index
        is specified by its index in `src` for `dimension != dim` and by the
        corresponding value in `index` for `dimension = dim`.
        """
        if torch.is_tensor(src):
            src = ArithmeticSharedTensor(src)
        assert isinstance(
            src, ArithmeticSharedTensor
        ), "Unrecognized scatter src type: %s" % type(src)
        self.share.scatter_(dim, index, src.share)
        return self

    def scatter(self, dim, index, src):
        """Writes all values from the tensor `src` into `self` at the indices
        specified in the `index` tensor. For each value in `src`, its output index
        is specified by its index in `src` for `dimension != dim` and by the
        corresponding value in `index` for `dimension = dim`.
        """
        result = self.clone()
        return result.scatter_(dim, index, src)
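A plaintext sketch of the public-operand convention in `_arithmetic_function`: only the rank-0 party adds an encoded public value to its share, so the revealed sum picks up the public term exactly once (the shares below are generated locally, purely for illustration).

import torch

world_size = 3
secret = torch.tensor([10, 20], dtype=torch.long)
public = torch.tensor([5, 7], dtype=torch.long)

# fake additive shares of `secret` (see the PRZS sketches earlier)
shares = [torch.randint(-2**62, 2**62, secret.size(), dtype=torch.long)
          for _ in range(world_size - 1)]
shares.append(secret - shares[0] - shares[1])

shares[0] = shares[0] + public          # only rank 0 adds the public operand
revealed = shares[0] + shares[1] + shares[2]
assert torch.equal(revealed, secret + public)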
Example 13
class ArithmeticSharedTensor(object):
    """
    Encrypted tensor object that uses additive sharing to perform computations.

    Additive shares are computed by splitting each value of the input tensor
    into n separate random values that add to the input tensor, where n is
    the number of parties present in the protocol (world_size).
    """

    # constructors:
    def __init__(
        self,
        tensor=None,
        size=None,
        broadcast_size=False,
        precision=None,
        src=0,
        device=None,
    ):
        """
        Creates the shared tensor from the input `tensor` provided by party `src`.

        The other parties can specify a `tensor` or `size` to determine the size
        of the shared tensor object to create. In this case, all parties must
        specify the same (tensor) size to prevent the parties' shares from varying
        in size, which leads to undefined behavior.

        Alternatively, the parties can set `broadcast_size` to `True` to have the
        `src` party broadcast the correct size. The parties who do not know the
        tensor size beforehand can provide an empty tensor as input. This is
        guaranteed to produce correct behavior but requires an additional
        communication round.

        The parties can also set the `precision` and `device` for their share of
        the tensor. If `device` is unspecified, it is set to `tensor.device`.
        """

        # do nothing if source is sentinel:
        if src == SENTINEL:
            return

        # assertions on inputs:
        assert (isinstance(src, int) and src >= 0
                and src < comm.get().get_world_size()
                ), "specified source party does not exist"
        if self.rank == src:
            assert tensor is not None, "source must provide a data tensor"
            if hasattr(tensor, "src"):
                assert (
                    tensor.src == src
                ), "source of data tensor must match source of encryption"
        if not broadcast_size:
            assert (tensor is not None or size is not None
                    ), "must specify tensor or size, or set broadcast_size"

        # if device is unspecified, try and get it from tensor:
        if device is None and tensor is not None and hasattr(tensor, "device"):
            device = tensor.device

        # encode the input tensor:
        self.encoder = FixedPointEncoder(precision_bits=precision)
        if tensor is not None:
            if is_int_tensor(tensor) and precision != 0:
                tensor = tensor.float()
            tensor = self.encoder.encode(tensor)
            tensor = tensor.to(device=device)
            size = tensor.size()

        # if other parties do not know tensor's size, broadcast the size:
        if broadcast_size:
            size = comm.get().broadcast_obj(size, src)

        # generate pseudo-random zero sharing (PRZS) and add source's tensor:
        self.share = ArithmeticSharedTensor.PRZS(size, device=device).share
        if self.rank == src:
            self.share += tensor

    @staticmethod
    def new(*args, **kwargs):
        """
        Creates a new ArithmeticSharedTensor, passing all args and kwargs into the constructor.
        """
        return ArithmeticSharedTensor(*args, **kwargs)

    @property
    def device(self):
        """Return the `torch.device` of the underlying _tensor"""
        return self._tensor.device

    @property
    def is_cuda(self):
        """Return True if the underlying _tensor is stored on GPU, False otherwise"""
        return self._tensor.is_cuda

    def to(self, *args, **kwargs):
        """Call `torch.Tensor.to` on the underlying _tensor"""
        self._tensor = self._tensor.to(*args, **kwargs)
        return self

    def cuda(self, *args, **kwargs):
        """Call `torch.Tensor.cuda` on the underlying _tensor"""
        self._tensor = CUDALongTensor(self._tensor.cuda(*args, **kwargs))
        return self

    def cpu(self, *args, **kwargs):
        """Call `torch.Tensor.cpu` on the underlying _tensor"""
        self._tensor = self._tensor.cpu(*args, **kwargs)
        return self

    @property
    def share(self):
        """Returns underlying _tensor"""
        return self._tensor

    @share.setter
    def share(self, value):
        """Sets _tensor to value"""
        self._tensor = value

    @staticmethod
    def from_shares(share, precision=None, device=None):
        """Generate an ArithmeticSharedTensor from a share from each party"""
        result = ArithmeticSharedTensor(src=SENTINEL)
        share = share.to(device) if device is not None else share
        result.share = CUDALongTensor(share) if share.is_cuda else share
        result.encoder = FixedPointEncoder(precision_bits=precision)
        return result

    @staticmethod
    def PRZS(*size, device=None):
        """
        Generate a Pseudo-random Sharing of Zero (using arithmetic shares)

        This function does so by generating `n` numbers across `n` parties with
        each number being held by exactly 2 parties. One of these parties adds
        this number while the other subtracts this number.
        """
        from crypten import generators

        tensor = ArithmeticSharedTensor(src=SENTINEL)
        if device is None:
            device = torch.device("cpu")
        elif isinstance(device, str):
            device = torch.device(device)
        g0 = generators["prev"][device]
        g1 = generators["next"][device]
        current_share = generate_random_ring_element(*size,
                                                     generator=g0,
                                                     device=device)
        next_share = generate_random_ring_element(*size,
                                                  generator=g1,
                                                  device=device)
        tensor.share = current_share - next_share
        return tensor

    @staticmethod
    def PRSS(*size, device=None):
        """
        Generates a Pseudo-random Secret Share from a set of random arithmetic shares
        """
        share = generate_random_ring_element(*size, device=device)
        tensor = ArithmeticSharedTensor.from_shares(share=share)
        return tensor

    @property
    def rank(self):
        return comm.get().get_rank()

    def shallow_copy(self):
        """Create a shallow copy"""
        result = ArithmeticSharedTensor(src=SENTINEL)
        result.encoder = self.encoder
        result._tensor = self._tensor
        return result

    def clone(self):
        result = ArithmeticSharedTensor(src=SENTINEL)
        result.encoder = self.encoder
        result._tensor = self._tensor.clone()
        return result

    def copy_(self, other):
        """Copies other tensor into this tensor."""
        self.share.copy_(other.share)
        self.encoder = other.encoder

    def __repr__(self):
        return f"ArithmeticSharedTensor({self.share})"

    def __bool__(self):
        """Override bool operator since encrypted tensors cannot evaluate"""
        raise RuntimeError(
            "Cannot evaluate ArithmeticSharedTensors to boolean values")

    def __nonzero__(self):
        """__bool__ for backwards compatibility with Python 2"""
        raise RuntimeError(
            "Cannot evaluate ArithmeticSharedTensors to boolean values")

    def __setitem__(self, index, value):
        """Set tensor values by index"""
        if isinstance(value, (int, float)) or is_tensor(value):
            value = ArithmeticSharedTensor(value)
        assert isinstance(
            value, ArithmeticSharedTensor
        ), "Unsupported input type %s for __setitem__" % type(value)
        self.share.__setitem__(index, value.share)

    def pad(self, pad, mode="constant", value=0):
        """
        Pads the input tensor with values provided in `value`.
        """
        assert mode == "constant", (
            "Padding with mode %s is currently unsupported" % mode)

        result = self.shallow_copy()
        if isinstance(value, (int, float)):
            value = self.encoder.encode(value).item()
            if result.rank == 0:
                result.share = torch.nn.functional.pad(result.share,
                                                       pad,
                                                       mode=mode,
                                                       value=value)
            else:
                result.share = torch.nn.functional.pad(result.share,
                                                       pad,
                                                       mode=mode,
                                                       value=0)
        elif isinstance(value, ArithmeticSharedTensor):
            assert (value.dim() == 0
                    ), "Private values used for padding must be 0-dimensional"
            value = value.share.item()
            result.share = torch.nn.functional.pad(result.share,
                                                   pad,
                                                   mode=mode,
                                                   value=value)
        else:
            raise TypeError(
                "Cannot pad ArithmeticSharedTensor with a %s value" %
                type(value))

        return result

    @staticmethod
    def stack(tensors, *args, **kwargs):
        """Perform tensor stacking"""
        for i, tensor in enumerate(tensors):
            if is_tensor(tensor):
                tensors[i] = ArithmeticSharedTensor(tensor)
            assert isinstance(
                tensors[i], ArithmeticSharedTensor
            ), "Can't stack %s with ArithmeticSharedTensor" % type(tensor)

        result = tensors[0].shallow_copy()
        result.share = torch_stack([tensor.share for tensor in tensors], *args,
                                   **kwargs)
        return result

    @staticmethod
    def reveal_batch(tensor_or_list, dst=None):
        """Get (batched) plaintext without any downscaling"""
        if isinstance(tensor_or_list, ArithmeticSharedTensor):
            return tensor_or_list.reveal(dst=dst)

        assert isinstance(
            tensor_or_list,
            list), f"Invalid input type into reveal {type(tensor_or_list)}"
        shares = [tensor.share for tensor in tensor_or_list]
        if dst is None:
            return comm.get().all_reduce(shares, batched=True)
        else:
            return comm.get().reduce(shares, dst, batched=True)

    def reveal(self, dst=None):
        """Decrypts the tensor without any downscaling."""
        tensor = self.share.clone()
        if dst is None:
            return comm.get().all_reduce(tensor)
        else:
            return comm.get().reduce(tensor, dst)

    def get_plain_text(self, dst=None):
        """Decrypts the tensor."""
        # Edge case where share becomes 0 sized (e.g. result of split)
        if self.nelement() < 1:
            return torch.empty(self.share.size())
        return self.encoder.decode(self.reveal(dst=dst))

    def encode_(self, new_encoder):
        """Rescales the input to a new encoding in-place"""
        if self.encoder.scale == new_encoder.scale:
            return self
        elif self.encoder.scale < new_encoder.scale:
            scale_factor = new_encoder.scale // self.encoder.scale
            self.share *= scale_factor
        else:
            scale_factor = self.encoder.scale // new_encoder.scale
            self = self.div_(scale_factor)
        self.encoder = new_encoder
        return self

    def encode(self, new_encoder):
        """Rescales the input to a new encoding"""
        return self.clone().encode_(new_encoder)

    def encode_as_(self, other):
        """Rescales self to have the same encoding as other"""
        return self.encode_(other.encoder)

    def encode_as(self, other):
        return self.encode(other.encoder)

    def _arithmetic_function_(self, y, op, *args, **kwargs):
        return self._arithmetic_function(y, op, inplace=True, *args, **kwargs)

    def _arithmetic_function(self,
                             y,
                             op,
                             inplace=False,
                             *args,
                             **kwargs):  # noqa:C901
        assert op in [
            "add",
            "sub",
            "mul",
            "matmul",
            "conv1d",
            "conv2d",
            "conv_transpose1d",
            "conv_transpose2d",
        ], f"Provided op `{op}` is not a supported arithmetic function"

        additive_func = op in ["add", "sub"]
        public = isinstance(y, (int, float)) or is_tensor(y)
        private = isinstance(y, ArithmeticSharedTensor)

        if inplace:
            result = self
            if additive_func or (op == "mul" and public):
                op += "_"
        else:
            result = self.clone()

        if public:
            y = result.encoder.encode(y, device=self.device)

            if additive_func:  # ['add', 'sub']
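                # only one party adds/subtracts the public constant; applying it on
                # every party would count it world_size times when shares are summed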
                if result.rank == 0:
                    result.share = getattr(result.share, op)(y)
                else:
                    result.share = torch.broadcast_tensors(result.share, y)[0]
            elif op == "mul_":  # ['mul_']
                result.share = result.share.mul_(y)
            else:  # ['mul', 'matmul', 'convNd', 'conv_transposeNd']
                result.share = getattr(torch, op)(result.share, y, *args,
                                                  **kwargs)
        elif private:
            if additive_func:  # ['add', 'sub', 'add_', 'sub_']
                # Re-encode if necessary:
                if self.encoder.scale > y.encoder.scale:
                    y.encode_as_(result)
                elif self.encoder.scale < y.encoder.scale:
                    result.encode_as_(y)
                result.share = getattr(result.share, op)(y.share)
            else:  # ['mul', 'matmul', 'convNd', 'conv_transposeNd']
                protocol = globals()[cfg.mpc.protocol]
                result.share.set_(
                    getattr(protocol, op)(result, y, *args,
                                          **kwargs).share.data)
        else:
            raise TypeError("Cannot %s %s with %s" % (op, type(y), type(self)))

        # Scale by encoder scale if necessary
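        # (a product of two fixed-point-encoded values carries an extra factor of the scale)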
        if not additive_func:
            if public:  # scale by self.encoder.scale
                if self.encoder.scale > 1:
                    return result.div_(result.encoder.scale)
                else:
                    result.encoder = self.encoder
            else:  # scale by larger of self.encoder.scale and y.encoder.scale
                if self.encoder.scale > 1 and y.encoder.scale > 1:
                    return result.div_(result.encoder.scale)
                elif self.encoder.scale > 1:
                    result.encoder = self.encoder
                else:
                    result.encoder = y.encoder

        return result

    def add(self, y):
        """Perform element-wise addition"""
        return self._arithmetic_function(y, "add")

    def add_(self, y):
        """Perform element-wise addition"""
        return self._arithmetic_function_(y, "add")

    def sub(self, y):
        """Perform element-wise subtraction"""
        return self._arithmetic_function(y, "sub")

    def sub_(self, y):
        """Perform element-wise subtraction"""
        return self._arithmetic_function_(y, "sub")

    def mul(self, y):
        """Perform element-wise multiplication"""
        if isinstance(y, int):
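            # a public Python int has scale 1, so the shares can be scaled directly
            # without the usual rescaling division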
            result = self.clone()
            result.share = self.share * y
            return result
        return self._arithmetic_function(y, "mul")

    def mul_(self, y):
        """Perform element-wise multiplication"""
        if isinstance(y, int) or is_int_tensor(y):
            self.share *= y
            return self
        return self._arithmetic_function_(y, "mul")

    def div(self, y):
        """Divide by a given tensor"""
        result = self.clone()
        if isinstance(y, CrypTensor):
            result.share = torch.broadcast_tensors(result.share,
                                                   y.share)[0].clone()
        elif is_tensor(y):
            result.share = torch.broadcast_tensors(result.share, y)[0].clone()
        return result.div_(y)

    def div_(self, y):
        """Divide two tensors element-wise"""
        # TODO: Add test coverage for this code path (next 4 lines)
        if isinstance(y, float) and int(y) == y:
            y = int(y)
        if is_float_tensor(y) and y.frac().eq(0).all():
            y = y.long()

        if isinstance(y, int) or is_int_tensor(y):
            validate = cfg.debug.validation_mode

            if validate:
                tolerance = 1.0
                tensor = self.get_plain_text()

            # Truncate protocol for dividing by public integers:
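            # with more than two parties, truncating each share locally can wrap
            # around the ring, so a dedicated truncation protocol is used instead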
            if comm.get().get_world_size() > 2:
                protocol = globals()[cfg.mpc.protocol]
                protocol.truncate(self, y)
            else:
                self.share = self.share.div_(y, rounding_mode="trunc")

            # Validate
            if validate:
                if not torch.lt(torch.abs(self.get_plain_text() * y - tensor),
                                tolerance).all():
                    raise ValueError("Final result of division is incorrect.")

            return self

        # Otherwise multiply by reciprocal
        if isinstance(y, float):
            y = torch.tensor([y], dtype=torch.float, device=self.device)

        assert is_float_tensor(y), "Unsupported type for div_: %s" % type(y)
        return self.mul_(y.reciprocal())

    def matmul(self, y):
        """Perform matrix multiplication using some tensor"""
        return self._arithmetic_function(y, "matmul")

    def conv1d(self, kernel, **kwargs):
        """Perform a 1D convolution using the given kernel"""
        return self._arithmetic_function(kernel, "conv1d", **kwargs)

    def conv2d(self, kernel, **kwargs):
        """Perform a 2D convolution using the given kernel"""
        return self._arithmetic_function(kernel, "conv2d", **kwargs)

    def conv_transpose1d(self, kernel, **kwargs):
        """Perform a 1D transpose convolution (deconvolution) using the given kernel"""
        return self._arithmetic_function(kernel, "conv_transpose1d", **kwargs)

    def conv_transpose2d(self, kernel, **kwargs):
        """Perform a 2D transpose convolution (deconvolution) using the given kernel"""
        return self._arithmetic_function(kernel, "conv_transpose2d", **kwargs)

    def index_add(self, dim, index, tensor):
        """Perform out-of-place index_add: Accumulate the elements of tensor into the
        self tensor by adding to the indices in the order given in index."""
        result = self.clone()
        return result.index_add_(dim, index, tensor)

    def index_add_(self, dim, index, tensor):
        """Perform in-place index_add: Accumulate the elements of tensor into the
        self tensor by adding to the indices in the order given in index."""
        public = isinstance(tensor, (int, float)) or is_tensor(tensor)
        private = isinstance(tensor, ArithmeticSharedTensor)
        if public:
            enc_tensor = self.encoder.encode(tensor)
            if self.rank == 0:
                self._tensor.index_add_(dim, index, enc_tensor)
        elif private:
            self._tensor.index_add_(dim, index, tensor._tensor)
        else:
            raise TypeError("index_add second tensor of unsupported type")
        return self

    def scatter_add(self, dim, index, other):
        """Adds all values from the tensor other into self at the indices
        specified in the index tensor in a similar fashion as scatter_(). For
        each value in other, it is added to an index in self which is specified
        by its index in other for dimension != dim and by the corresponding
        value in index for dimension = dim.
        """
        return self.clone().scatter_add_(dim, index, other)

    def scatter_add_(self, dim, index, other):
        """Adds all values from the tensor other into self at the indices
        specified in the index tensor in a similar fashion as scatter_(). For
        each value in other, it is added to an index in self which is specified
        by its index in other for dimension != dim and by the corresponding
        value in index for dimension = dim.
        """
        public = isinstance(other, (int, float)) or is_tensor(other)
        private = isinstance(other, ArithmeticSharedTensor)
        if public:
            if self.rank == 0:
                self.share.scatter_add_(dim, index, self.encoder.encode(other))
        elif private:
            self.share.scatter_add_(dim, index, other.share)
        else:
            raise TypeError("scatter_add second tensor of unsupported type")
        return self

    def avg_pool2d(self, kernel_size, stride=None, padding=0, ceil_mode=False):
        """Perform an average pooling on each 2D matrix of the given tensor

        Args:
            kernel_size (int or tuple): pooling kernel size.
        """
        # TODO: Add check for whether ceil_mode would change size of output and allow ceil_mode when it wouldn't
        if ceil_mode:
            raise NotImplementedError(
                "CrypTen does not support `ceil_mode` for `avg_pool2d`")

        z = self._sum_pool2d(kernel_size,
                             stride=stride,
                             padding=padding,
                             ceil_mode=ceil_mode)
        if isinstance(kernel_size, (int, float)):
            pool_size = kernel_size**2
        else:
            pool_size = kernel_size[0] * kernel_size[1]
        return z / pool_size

    def _sum_pool2d(self,
                    kernel_size,
                    stride=None,
                    padding=0,
                    ceil_mode=False):
        """Perform a sum pooling on each 2D matrix of the given tensor"""
        result = self.shallow_copy()

        result.share = torch.nn.functional.avg_pool2d(
            self.share,
            kernel_size,
            stride=stride,
            padding=padding,
            ceil_mode=ceil_mode,
            divisor_override=1,
        )
        return result

    # negation and reciprocal:
    def neg_(self):
        """Negate the tensor's values"""
        self.share.neg_()
        return self

    def neg(self):
        """Negate the tensor's values"""
        return self.clone().neg_()

    def square_(self):
        protocol = globals()[cfg.mpc.protocol]
        self.share = protocol.square(self).div_(self.encoder.scale).share
        return self

    def square(self):
        return self.clone().square_()

    def where(self, condition, y):
        """Selects elements from self or y based on condition

        Args:
            condition (torch.bool or ArithmeticSharedTensor): when True
                yield self, otherwise yield y.
            y (torch.tensor or ArithmeticSharedTensor): values selected at
                indices where condition is False.

        Returns: ArithmeticSharedTensor or torch.tensor
        """
        if is_tensor(condition):
            condition = condition.float()
            y_masked = y * (1 - condition)
        else:
            # encrypted tensor must be first operand
            y_masked = (1 - condition) * y

        return self * condition + y_masked

    def scatter_(self, dim, index, src):
        """Writes all values from the tensor `src` into `self` at the indices
        specified in the `index` tensor. For each value in `src`, its output index
        is specified by its index in `src` for `dimension != dim` and by the
        corresponding value in `index` for `dimension = dim`.
        """
        if is_tensor(src):
            src = ArithmeticSharedTensor(src)
        assert isinstance(src, ArithmeticSharedTensor
                          ), "Unrecognized scatter src type: %s" % type(src)
        self.share.scatter_(dim, index, src.share)
        return self

    def scatter(self, dim, index, src):
        """Writes all values from the tensor `src` into `self` at the indices
        specified in the `index` tensor. For each value in `src`, its output index
        is specified by its index in `src` for `dimension != dim` and by the
        corresponding value in `index` for `dimension = dim`.
        """
        result = self.clone()
        return result.scatter_(dim, index, src)

    # overload operators:
    __add__ = add
    __iadd__ = add_
    __radd__ = __add__
    __sub__ = sub
    __isub__ = sub_
    __mul__ = mul
    __imul__ = mul_
    __rmul__ = __mul__
    __div__ = div
    __truediv__ = div
    __itruediv__ = div_
    __neg__ = neg

    def __rsub__(self, tensor):
        """Subtracts self from tensor."""
        return -self + tensor

    @property
    def data(self):
        return self._tensor.data

    @data.setter
    def data(self, value):
        self._tensor.set_(value)
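
A minimal single-process sketch of the two ideas the class above relies on: additive secret sharing of fixed-point-encoded values, and the extra rescaling that a product of two encoded values requires. It uses plain torch tensors instead of CrypTen's communicator, and the helper names (`encode`, `make_shares`, `reveal`) are illustrative only.

import torch

SCALE = 2 ** 16  # assumed 16 bits of fixed-point precision

def encode(x):
    return (x * SCALE).long()

def make_shares(x_enc, n_parties=2):
    # split an encoded value into additive shares that sum back to it
    shares = [torch.randint(-2 ** 40, 2 ** 40, x_enc.shape) for _ in range(n_parties - 1)]
    shares.append(x_enc - sum(shares))
    return shares

def reveal(shares):
    return sum(shares)

x = torch.tensor([1.5, -2.25])
y = torch.tensor([4.0, 0.5])
x_shares = make_shares(encode(x))

# adding a public constant: only one party ("rank 0") adds it, otherwise it
# would be counted once per party when the shares are summed
x_shares[0] = x_shares[0] + encode(y)
print(reveal(x_shares).float() / SCALE)   # ~= x + y

# a product of two encoded values carries SCALE**2, so divide once by SCALE
z_enc = encode(x) * encode(y)
print((z_enc // SCALE).float() / SCALE)   # ~= x * y
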
Example 14
0
class BinarySharedTensor(object):
    """
    Encrypted tensor object that uses binary sharing to perform computations.

    Binary shares are computed by splitting each value of the input tensor
    into n separate random values that xor together to the input tensor value,
    where n is the number of parties present in the protocol (world_size).
    """

    def __init__(
        self, tensor=None, size=None, broadcast_size=False, src=0, device=None
    ):
        """
        Creates the shared tensor from the input `tensor` provided by party `src`.

        The other parties can specify a `tensor` or `size` to determine the size
        of the shared tensor object to create. In this case, all parties must
        specify the same (tensor) size to prevent the party's shares from varying
        in size, which leads to undefined behavior.

        Alternatively, the parties can set `broadcast_size` to `True` to have the
        `src` party broadcast the correct size. The parties who do not know the
        tensor size beforehand can provide an empty tensor as input. This is
        guaranteed to produce correct behavior but requires an additional
        communication round.

        The parties can also set the `device` for their share of the tensor. If
        `device` is unspecified, it is set to `tensor.device`.
        """

        # do nothing if source is sentinel:
        if src == SENTINEL:
            return

        # assertions on inputs:
        assert (
            isinstance(src, int) and src >= 0 and src < comm.get().get_world_size()
        ), "specified source party does not exist"
        if self.rank == src:
            assert tensor is not None, "source must provide a data tensor"
            if hasattr(tensor, "src"):
                assert (
                    tensor.src == src
                ), "source of data tensor must match source of encryption"
        if not broadcast_size:
            assert (
                tensor is not None or size is not None
            ), "must specify tensor or size, or set broadcast_size"

        # if device is unspecified, try and get it from tensor:
        if device is None and tensor is not None and hasattr(tensor, "device"):
            device = tensor.device

        # assume zero bits of precision unless encoder is set outside of init:
        self.encoder = FixedPointEncoder(precision_bits=0)
        if tensor is not None:
            tensor = self.encoder.encode(tensor)
            tensor = tensor.to(device=device)
            size = tensor.size()

        # if other parties do not know tensor's size, broadcast the size:
        if broadcast_size:
            size = comm.get().broadcast_obj(size, src)

        # generate pseudo-random zero sharing (PRZS) and xor in source's tensor:
        self.share = BinarySharedTensor.PRZS(size, device=device).share
        if self.rank == src:
            self.share ^= tensor

    @staticmethod
    def new(*args, **kwargs):
        """
        Creates a new BinarySharedTensor, passing all args and kwargs into the constructor.
        """
        return BinarySharedTensor(*args, **kwargs)

    @staticmethod
    def from_shares(share, precision=None, src=0, device=None):
        """Generate a BinarySharedTensor from a share from each party"""
        result = BinarySharedTensor(src=SENTINEL)
        share = share.to(device) if device is not None else share
        result.share = CUDALongTensor(share) if share.is_cuda else share
        result.encoder = FixedPointEncoder(precision_bits=precision)
        return result

    @staticmethod
    def PRZS(*size, device=None):
        """
        Generate a Pseudo-random Sharing of Zero (using binary shares)

        This function does so by generating `n` numbers across `n` parties with
        each number being held by exactly 2 parties. Therefore, each party holds
        two numbers. A zero sharing is found by having each party xor their two
        numbers together.
        """
        from crypten import generators

        tensor = BinarySharedTensor(src=SENTINEL)
        if device is None:
            device = torch.device("cpu")
        elif isinstance(device, str):
            device = torch.device(device)
        g0 = generators["prev"][device]
        g1 = generators["next"][device]
        current_share = generate_kbit_random_tensor(*size, device=device, generator=g0)
        next_share = generate_kbit_random_tensor(*size, device=device, generator=g1)
        tensor.share = current_share ^ next_share
        return tensor

    @staticmethod
    def rand(*size, bits=64, device=None):
        """
        Generate uniform random samples of a given size.
        """
        tensor = BinarySharedTensor(src=SENTINEL)
        if isinstance(size[0], (torch.Size, tuple)):
            size = size[0]
        tensor.share = generate_kbit_random_tensor(size, bitlength=bits, device=device)
        return tensor

    @property
    def device(self):
        """Return the `torch.device` of the underlying _tensor"""
        return self._tensor.device

    @property
    def is_cuda(self):
        """Return True if the underlying _tensor is stored on GPU, False otherwise"""
        return self._tensor.is_cuda

    def to(self, *args, **kwargs):
        """Call `torch.Tensor.to` on the underlying _tensor"""
        self._tensor = self._tensor.to(*args, **kwargs)
        return self

    def cuda(self, *args, **kwargs):
        """Call `torch.Tensor.cuda` on the underlying _tensor"""
        self._tensor = CUDALongTensor(self._tensor.cuda(*args, **kwargs))
        return self

    def cpu(self, *args, **kwargs):
        """Call `torch.Tensor.cpu` on the underlying _tensor"""
        self._tensor = self._tensor.cpu(*args, **kwargs)
        return self

    @property
    def rank(self):
        return comm.get().get_rank()

    @property
    def share(self):
        """Returns underlying _tensor"""
        return self._tensor

    @share.setter
    def share(self, value):
        """Sets _tensor to value"""
        self._tensor = value

    def shallow_copy(self):
        """Create a shallow copy"""
        result = BinarySharedTensor(src=SENTINEL)
        result.encoder = self.encoder
        result._tensor = self._tensor
        return result

    def clone(self):
        result = BinarySharedTensor(src=SENTINEL)
        result.encoder = self.encoder
        result._tensor = self._tensor.clone()
        return result

    def copy_(self, other):
        """Copies other tensor into this tensor."""
        self.share.copy_(other.share)
        self.encoder = other.encoder

    def __repr__(self):
        return f"BinarySharedTensor({self.share})"

    def __bool__(self):
        """Override bool operator since encrypted tensors cannot evaluate"""
        raise RuntimeError("Cannot evaluate BinarySharedTensors to boolean values")

    def __nonzero__(self):
        """__bool__ for backwards compatibility with Python 2"""
        raise RuntimeError("Cannot evaluate BinarySharedTensors to boolean values")

    def __ixor__(self, y):
        """Bitwise XOR operator (element-wise) in place"""
        if is_tensor(y) or isinstance(y, int):
            if self.rank == 0:
                self.share ^= y
        elif isinstance(y, BinarySharedTensor):
            self.share ^= y.share
        else:
            raise TypeError("Cannot XOR %s with %s." % (type(y), type(self)))
        return self

    def __xor__(self, y):
        """Bitwise XOR operator (element-wise)"""
        result = self.clone()
        if isinstance(y, BinarySharedTensor):
            broadcast_tensors = torch.broadcast_tensors(result.share, y.share)
            result.share = broadcast_tensors[0].clone()
        elif is_tensor(y):
            broadcast_tensors = torch.broadcast_tensors(result.share, y)
            result.share = broadcast_tensors[0].clone()
        return result.__ixor__(y)

    def __iand__(self, y):
        """Bitwise AND operator (element-wise) in place"""
        if is_tensor(y) or isinstance(y, int):
            self.share &= y
        elif isinstance(y, BinarySharedTensor):
            self.share.set_(beaver.AND(self, y).share.data)
        else:
            raise TypeError("Cannot AND %s with %s." % (type(y), type(self)))
        return self

    def __and__(self, y):
        """Bitwise AND operator (element-wise)"""
        result = self.clone()
        # TODO: Remove explicit broadcasts to allow smaller beaver triples
        if isinstance(y, BinarySharedTensor):
            broadcast_tensors = torch.broadcast_tensors(result.share, y.share)
            result.share = broadcast_tensors[0].clone()
        elif is_tensor(y):
            broadcast_tensors = torch.broadcast_tensors(result.share, y)
            result.share = broadcast_tensors[0].clone()
        return result.__iand__(y)

    def __ior__(self, y):
        """Bitwise OR operator (element-wise) in place"""
        xor_result = self ^ y
        return self.__iand__(y).__ixor__(xor_result)

    def __or__(self, y):
        """Bitwise OR operator (element-wise)"""
        return self.__and__(y) ^ self ^ y

    def __invert__(self):
        """Bitwise NOT operator (element-wise)"""
        result = self.clone()
        if result.rank == 0:
            result.share ^= -1
        return result

    def lshift_(self, value):
        """Left shift elements by `value` bits"""
        assert isinstance(value, int), "lshift must take an integer argument."
        self.share <<= value
        return self

    def lshift(self, value):
        """Left shift elements by `value` bits"""
        return self.clone().lshift_(value)

    def rshift_(self, value):
        """Right shift elements by `value` bits"""
        assert isinstance(value, int), "rshift must take an integer argument."
        self.share >>= value
        return self

    def rshift(self, value):
        """Right shift elements by `value` bits"""
        return self.clone().rshift_(value)

    # Circuits
    def add(self, y):
        """Compute [self] + [y] for xor-sharing"""
        return circuit.add(self, y)

    def eq(self, y):
        return circuit.eq(self, y)

    def ne(self, y):
        return self.eq(y) ^ 1

    def lt(self, y):
        return circuit.lt(self, y)

    def le(self, y):
        return circuit.le(self, y)

    def gt(self, y):
        return circuit.gt(self, y)

    def ge(self, y):
        return circuit.ge(self, y)

    def __setitem__(self, index, value):
        """Set tensor values by index"""
        if is_tensor(value) or isinstance(value, list):
            value = BinarySharedTensor(value)
        assert isinstance(
            value, BinarySharedTensor
        ), "Unsupported input type %s for __setitem__" % type(value)
        self.share.__setitem__(index, value.share)

    @staticmethod
    def stack(seq, *args, **kwargs):
        """Stacks a list of tensors along a given dimension"""
        assert isinstance(seq, list), "Stack input must be a list"
        assert isinstance(
            seq[0], BinarySharedTensor
        ), "Sequence must contain BinarySharedTensors"
        result = seq[0].shallow_copy()
        result.share = torch_stack(
            [tensor.share for tensor in seq], *args, **kwargs
        )
        return result

    def sum(self, dim=None):
        """Add all tensors along a given dimension using a log-reduction"""
        if dim is None:
            x = self.flatten()
        else:
            x = self.transpose(0, dim)

        # Add all BinarySharedTensors
        while x.size(0) > 1:
            extra = None
            if x.size(0) % 2 == 1:
                extra = x[0]
                x = x[1:]
            x0 = x[: (x.size(0) // 2)]
            x1 = x[(x.size(0) // 2) :]
            x = x0 + x1
            if extra is not None:
                x.share = torch_cat([x.share, extra.share.unsqueeze(0)])

        if dim is None:
            x = x.squeeze()
        else:
            x = x.transpose(0, dim).squeeze(dim)
        return x

    def cumsum(self, *args, **kwargs):
        raise NotImplementedError("BinarySharedTensor cumsum not implemented")

    def trace(self, *args, **kwargs):
        raise NotImplementedError("BinarySharedTensor trace not implemented")

    @staticmethod
    def reveal_batch(tensor_or_list, dst=None):
        """Get (batched) plaintext without any downscaling"""
        if isinstance(tensor_or_list, BinarySharedTensor):
            return tensor_or_list.reveal(dst=dst)

        assert isinstance(
            tensor_or_list, list
        ), f"Invalid input type into reveal {type(tensor_or_list)}"
        shares = [tensor.share for tensor in tensor_or_list]
        op = torch.distributed.ReduceOp.BXOR
        if dst is None:
            return comm.get().all_reduce(shares, op=op, batched=True)
        else:
            return comm.get().reduce(shares, dst, op=op, batched=True)

    def reveal(self, dst=None):
        """Get plaintext without any downscaling"""
        op = torch.distributed.ReduceOp.BXOR
        if dst is None:
            return comm.get().all_reduce(self.share, op=op)
        else:
            return comm.get().reduce(self.share, dst, op=op)

    def get_plain_text(self, dst=None):
        """Decrypts the tensor."""
        # Edge case where share becomes 0 sized (e.g. result of split)
        if self.nelement() < 1:
            return torch.empty(self.share.size())
        return self.encoder.decode(self.reveal(dst=dst))

    def where(self, condition, y):
        """Selects elements from self or y based on condition

        Args:
            condition (torch.bool or BinarySharedTensor): when True yield self,
                otherwise yield y. Note condition is not bitwise.
            y (torch.tensor or BinarySharedTensor): selected when condition is
                False.

        Returns: BinarySharedTensor or torch.tensor.
        """
        if is_tensor(condition):
            condition = condition.long()
            is_binary = ((condition == 1) | (condition == 0)).all()
            assert is_binary, "condition values must be 0 or 1"
            # -1 mult expands 0 into binary 00...00 and 1 into 11...11
            condition_expanded = -condition
            y_masked = y & (~condition_expanded)
        elif isinstance(condition, BinarySharedTensor):
            condition_expanded = condition.clone()
            # -1 mult expands binary while & 1 isolates first bit
            condition_expanded.share = -(condition_expanded.share & 1)
            # encrypted tensor must be first operand
            y_masked = (~condition_expanded) & y
        else:
            msg = f"condition {condition} must be torch.bool, or BinarySharedTensor"
            raise ValueError(msg)

        return (self & condition_expanded) ^ y_masked

    def scatter_(self, dim, index, src):
        """Writes all values from the tensor `src` into `self` at the indices
        specified in the `index` tensor. For each value in `src`, its output index
        is specified by its index in `src` for `dimension != dim` and by the
        corresponding value in `index` for `dimension = dim`.
        """
        if is_tensor(src):
            src = BinarySharedTensor(src)
        assert isinstance(
            src, BinarySharedTensor
        ), "Unrecognized scatter src type: %s" % type(src)
        self.share.scatter_(dim, index, src.share)
        return self

    def scatter(self, dim, index, src):
        """Writes all values from the tensor `src` into `self` at the indices
        specified in the `index` tensor. For each value in `src`, its output index
        is specified by its index in `src` for `dimension != dim` and by the
        corresponding value in `index` for `dimension = dim`.
        """
        result = self.clone()
        return result.scatter_(dim, index, src)

    # Bitwise operators
    __add__ = add
    __eq__ = eq
    __ne__ = ne
    __lt__ = lt
    __le__ = le
    __gt__ = gt
    __ge__ = ge
    __lshift__ = lshift
    __rshift__ = rshift

    # In-place bitwise operators
    __ilshift__ = lshift_
    __irshift__ = rshift_

    # Reversed boolean operations
    __radd__ = __add__
    __rxor__ = __xor__
    __rand__ = __and__
    __ror__ = __or__
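
A minimal single-process sketch of the PRZS construction described in `PRZS` above: party i holds the random values r_i and r_{(i+1) mod n}, and xoring its two values gives it a share of zero because every r_j appears in exactly two shares. It uses plain torch tensors rather than CrypTen's generators and communicator.

import torch

n_parties = 3
size = (4,)
r = [torch.randint(0, 2 ** 62, size) for _ in range(n_parties)]
przs_shares = [r[i] ^ r[(i + 1) % n_parties] for i in range(n_parties)]

zero = przs_shares[0]
for s in przs_shares[1:]:
    zero = zero ^ s
print(zero)  # all zeros: the PRZS shares xor to the zero tensor

# the source party xors its secret into its share of zero; the xor of all
# shares then reveals the secret, mirroring `reveal` above
secret = torch.tensor([3, 7, 1, 0])
przs_shares[0] = przs_shares[0] ^ secret
revealed = przs_shares[0]
for s in przs_shares[1:]:
    revealed = revealed ^ s
print(revealed)  # equals `secret`
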
Example 15
0
    def from_shares(share, precision=None, src=0):
        """Generate an ArithmeticSharedTensor from a share from each party"""
        result = ArithmeticSharedTensor(src=SENTINEL)
        result.share = CUDALongTensor(share) if share.is_cuda else share
        result.encoder = FixedPointEncoder(precision_bits=precision)
        return result
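
A hypothetical usage sketch for the `from_shares` snippet above: each party wraps the raw int64 share it already holds (for example, loaded from disk), so no new sharing round is needed. It assumes CrypTen has been initialized with `crypten.init()` and that the import path below matches the installed CrypTen version.

import torch
import crypten
from crypten.mpc.primitives import ArithmeticSharedTensor

crypten.init()
raw_share = torch.zeros(3, dtype=torch.long)  # placeholder for this party's share
enc = ArithmeticSharedTensor.from_shares(raw_share, precision=16)
print(enc.get_plain_text())  # decodes after combining all parties' shares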