Exemple #1
0
 def __and__(self, y):
     """Return the element-wise bitwise AND of this tensor and `y`.

     The work is done on `self.clone()`. When `y` is a BinarySharedTensor or
     a plain tensor, the clone's share is first expanded to the broadcast
     shape of both operands; the actual AND is delegated to `__iand__`.
     """
     out = self.clone()
     # TODO: Remove explicit broadcasts to allow smaller beaver triples
     if isinstance(y, BinarySharedTensor):
         operands = (out.share, y.share)
     elif is_tensor(y):
         operands = (out.share, y)
     else:
         # Nothing to broadcast against; hand `y` straight to the in-place op.
         return out.__iand__(y)
     # Keep only our own broadcast share, copied into fresh storage.
     out.share = torch.broadcast_tensors(*operands)[0].clone()
     return out.__iand__(y)
Exemple #2
0
 def scatter_(self, dim, index, src):
     """Scatter the values of `src` into `self` in place.

     Every value of `src` is written to the position whose coordinate along
     `dim` is given by the matching entry of `index`; along all other
     dimensions the value keeps the coordinate it has in `src`.

     A plain tensor `src` is wrapped in an ArithmeticSharedTensor first.
     Returns `self`.
     """
     wrapped = ArithmeticSharedTensor(src) if is_tensor(src) else src
     assert isinstance(
         wrapped, ArithmeticSharedTensor
     ), "Unrecognized scatter src type: %s" % type(wrapped)
     self.share.scatter_(dim, index, wrapped.share)
     return self
Exemple #3
0
 def scatter_add_(self, dim, index, other):
     """Add the values of `other` into `self` in place at the positions
     selected by `index` along `dim`.

     `other` may be an int, a float, a plain tensor, or a CrypTensor (in
     which case its `_tensor` is used). Any other type raises TypeError.
     Returns `self`.
     """
     if isinstance(other, (int, float)) or is_tensor(other):
         addend = other
     elif isinstance(other, CrypTensor):
         addend = other._tensor
     else:
         raise TypeError("scatter_add second tensor of unsupported type")
     self._tensor.scatter_add_(dim, index, addend)
     return self
Exemple #4
0
    def stack(tensors, *args, **kwargs):
        """Stack a sequence of tensors into a single ArithmeticSharedTensor.

        Plain tensors (as recognised by `is_tensor`) are wrapped in an
        ArithmeticSharedTensor before stacking. The conversion is done on a
        local copy, so the caller's sequence is never modified and immutable
        sequences such as tuples are accepted.

        Args:
            tensors: iterable of ArithmeticSharedTensor and/or plain tensors.
            *args, **kwargs: forwarded unchanged to `torch_stack`.

        Returns:
            A shallow copy of the first (converted) element whose `share` is
            the stack of every element's share.

        Raises:
            AssertionError: if an element is neither a plain tensor nor an
                ArithmeticSharedTensor.
            IndexError: if `tensors` is empty.
        """
        converted = []
        for tensor in tensors:
            item = ArithmeticSharedTensor(tensor) if is_tensor(tensor) else tensor
            assert isinstance(
                item, ArithmeticSharedTensor
            ), "Can't stack %s with ArithmeticSharedTensor" % type(tensor)
            converted.append(item)

        result = converted[0].shallow_copy()
        result.share = torch_stack(
            [item.share for item in converted], *args, **kwargs
        )
        return result
Exemple #5
0
    def set(self, enc_tensor):
        """
        Point this tensor's share at the share of `enc_tensor`, in place.

        A plain tensor is first wrapped in an MPCTensor.

        Args:
            enc_tensor (MPCTensor): tensor whose encrypted share is adopted.

        Returns:
            self
        """
        source = MPCTensor(enc_tensor) if is_tensor(enc_tensor) else enc_tensor
        assert isinstance(source, MPCTensor), "enc_tensor must be an MPCTensor"
        self.share.set_(source.share)
        return self
Exemple #6
0
 def index_add_(self, dim, index, tensor):
     """In-place index_add: accumulate the elements of `tensor` into `self`
     along `dim`, at the positions listed in `index`.

     An int, float or plain tensor is encoded with `self.encoder` and added
     on rank 0 only. An ArithmeticSharedTensor contributes its `_tensor` on
     every rank. Any other type raises TypeError. Returns `self`.
     """
     if isinstance(tensor, (int, float)) or is_tensor(tensor):
         # Encoding runs on every rank; only rank 0 accumulates the result.
         encoded = self.encoder.encode(tensor)
         if self.rank == 0:
             self._tensor.index_add_(dim, index, encoded)
         return self
     if isinstance(tensor, ArithmeticSharedTensor):
         self._tensor.index_add_(dim, index, tensor._tensor)
         return self
     raise TypeError("index_add second tensor of unsupported type")
Exemple #7
0
 def index_add_(self, dim, index, tensor):
     """In-place index_add: accumulate the elements of `tensor` into `self`
     along `dim`, at the positions listed in the 1-D `index`.

     `tensor` may be an int, a float, a plain tensor, or an MPCTensor (in
     which case its `_tensor` is used). Any other type raises TypeError.
     Returns `self`.
     """
     assert index.dim() == 1, "index needs to be a vector"
     if isinstance(tensor, (int, float)) or is_tensor(tensor):
         addend = tensor
     elif isinstance(tensor, MPCTensor):
         addend = tensor._tensor
     else:
         raise TypeError("index_add second tensor of unsupported type")
     self._tensor.index_add_(dim, index, addend)
     return self
Exemple #8
0
 def scatter_add_(self, dim, index, other):
     """Add the values of `other` into `self` in place, scatter-style.

     Each value of `other` is added at the position whose coordinate along
     `dim` is taken from the matching entry of `index`; along every other
     dimension the value keeps the coordinate it has in `other`.

     An int, float or plain tensor is encoded with `self.encoder` and added
     on rank 0 only. An ArithmeticSharedTensor contributes its share on
     every rank. Any other type raises TypeError. Returns `self`.
     """
     if isinstance(other, (int, float)) or is_tensor(other):
         if self.rank != 0:
             # Ranks other than 0 leave their share unchanged.
             return self
         source = self.encoder.encode(other)
     elif isinstance(other, ArithmeticSharedTensor):
         source = other.share
     else:
         raise TypeError("scatter_add second tensor of unsupported type")
     self.share.scatter_add_(dim, index, source)
     return self
Exemple #9
0
    def where(self, condition, y):
        """Pick elements from `self` or `y` according to `condition`.

        Args:
            condition (torch.bool or ArithmeticSharedTensor): where True the
                element of `self` is taken, otherwise the element of `y`.
            y (torch.tensor or ArithmeticSharedTensor): values used where
                `condition` is False.

        Returns: ArithmeticSharedTensor or torch.tensor
        """
        if is_tensor(condition):
            condition = condition.float()
            complement = 1 - condition
            y_masked = y * complement
        else:
            complement = 1 - condition
            # encrypted tensor must be first operand
            y_masked = complement * y

        kept = self * condition
        return kept + y_masked
Exemple #10
0
    def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs):
        """Apply the arithmetic operation named `op` between this tensor and `y`.

        `y` is treated as public when it is an int, a float, or anything
        `is_tensor` accepts, and as private when it is an
        ArithmeticSharedTensor.

        Args:
            y: public or private right-hand operand (see above).
            op (str): one of "add", "sub", "mul", "matmul", "conv1d",
                "conv2d", "conv_transpose1d", "conv_transpose2d".
            inplace (bool): when True the result is written into `self`;
                otherwise into `self.clone()`.
            *args, **kwargs: forwarded to the `torch` function (public
                operand, non-additive op other than in-place "mul") or to
                the `beaver` function (private operand, non-additive op).
                They are not used for "add"/"sub".

        Returns:
            `self` (in-place) or its clone holding the result. On the paths
            that rescale, the value returned by `result.div_(...)` is
            returned instead.

        Raises:
            AssertionError: if `op` is not one of the supported names.
            TypeError: if `y` is neither public nor private.
        """
        assert op in [
            "add",
            "sub",
            "mul",
            "matmul",
            "conv1d",
            "conv2d",
            "conv_transpose1d",
            "conv_transpose2d",
        ], f"Provided op `{op}` is not a supported arithmetic function"

        # Computed before `op` may get a trailing "_" below, so it stays
        # True for the in-place variants as well.
        additive_func = op in ["add", "sub"]
        public = isinstance(y, (int, float)) or is_tensor(y)
        private = isinstance(y, ArithmeticSharedTensor)

        if inplace:
            result = self
            # Switch to the in-place method name ("add_", "sub_", "mul_").
            # Private "mul" keeps its name and is made in-place via set_ below.
            if additive_func or (op == "mul" and public):
                op += "_"
        else:
            result = self.clone()

        if public:
            # Public operands are encoded before being combined with the share.
            y = result.encoder.encode(y, device=self.device)

            if additive_func:  # ['add', 'sub', 'add_', 'sub_']
                # Only rank 0 applies the public operand; every other rank
                # just expands its share to the broadcast shape of both.
                if result.rank == 0:
                    result.share = getattr(result.share, op)(y)
                else:
                    result.share = torch.broadcast_tensors(result.share, y)[0]
            elif op == "mul_":  # ['mul_']
                result.share = result.share.mul_(y)
            else:  # ['mul', 'matmul', 'convNd', 'conv_transposeNd']
                result.share = getattr(torch, op)(result.share, y, *args,
                                                  **kwargs)
        elif private:
            if additive_func:  # ['add', 'sub', 'add_', 'sub_']
                result.share = getattr(result.share, op)(y.share)
            else:  # ['mul', 'matmul', 'convNd', 'conv_transposeNd']
                # NOTE: 'mul_' calls 'mul' here
                # Must copy share.data here to support 'mul_' being inplace
                result.share.set_(
                    getattr(beaver, op)(result, y, *args, **kwargs).share.data)
        else:
            raise TypeError("Cannot %s %s with %s" % (op, type(y), type(self)))

        # Scale by encoder scale if necessary
        # NOTE(review): presumably a product of two encoded values carries the
        # scale twice, hence the single division -- confirm against the encoder.
        if not additive_func:
            if public:  # scale by self.encoder.scale
                if self.encoder.scale > 1:
                    return result.div_(result.encoder.scale)
                else:
                    result.encoder = self.encoder
            else:  # scale by larger of self.encoder.scale and y.encoder.scale
                if self.encoder.scale > 1 and y.encoder.scale > 1:
                    return result.div_(result.encoder.scale)
                elif self.encoder.scale > 1:
                    # Only self is scaled: keep self's encoder, no division.
                    result.encoder = self.encoder
                else:
                    # Otherwise adopt y's encoder, again without dividing.
                    result.encoder = y.encoder

        return result