class MPCTensor(CrypTensor):
    """Encrypted tensor backed by a secret-shared tensor.

    The underlying shared tensor is stored in ``self._tensor`` and its type is
    chosen by ``ptype.to_tensor()``; the sharing type itself is recorded in
    ``self.ptype``.
    """

    def __init__(self, tensor, ptype=Ptype.arithmetic, device=None, *args, **kwargs):
        """
        Creates the shared tensor from the input `tensor` provided by party `src`.
        The `ptype` defines the type of sharing used (default: arithmetic).

        The other parties can specify a `tensor` or `size` to determine the size
        of the shared tensor object to create. In this case, all parties must
        specify the same (tensor) size to prevent the party's shares from varying
        in size, which leads to undefined behavior.

        Alternatively, the parties can set `broadcast_size` to `True` to have the
        `src` party broadcast the correct size. The parties who do not know the
        tensor size beforehand can provide an empty tensor as input. This is
        guaranteed to produce correct behavior but requires an additional
        communication round.

        The parties can also set the `precision` and `device` for their share of
        the tensor. If `device` is unspecified, it is set to `tensor.device`.

        Raises:
            ValueError: if `tensor` is None.
        """
        if tensor is None:
            raise ValueError("Cannot initialize tensor with None.")

        # take required_grad from kwargs, input tensor, or set to False:
        default = tensor.requires_grad if torch.is_tensor(tensor) else False
        requires_grad = kwargs.pop("requires_grad", default)

        # call CrypTensor constructor:
        super().__init__(requires_grad=requires_grad)
        if device is None and hasattr(tensor, "device"):
            device = tensor.device

        # create the MPCTensor:
        # NOTE: an empty-list input (`MPCTensor([])`, used as a placeholder by
        # the copy / constructor helpers in this class) goes through the same
        # path as every other input. A former `tensor is []` shortcut could
        # never trigger, because an identity comparison against a list literal
        # is always False.
        tensor_type = ptype.to_tensor()
        self._tensor = tensor_type(*args, tensor=tensor, device=device, **kwargs)
        self.ptype = ptype

    @staticmethod
    def new(*args, **kwargs):
        """
        Creates a new MPCTensor, passing all args and kwargs into the constructor.
        """
        return MPCTensor(*args, **kwargs)

    @staticmethod
    def from_shares(share, precision=None, src=0, ptype=Ptype.arithmetic):
        """Builds an MPCTensor of the given `ptype` directly from a `share`.

        `share`, `precision` and `src` are forwarded to the `from_shares`
        constructor of the shared-tensor type selected by `ptype`.
        """
        result = MPCTensor([])
        from_shares = ptype.to_tensor().from_shares
        result._tensor = from_shares(share, precision=precision, src=src)
        result.ptype = ptype
        return result

    def clone(self):
        """Create a deep copy of the input tensor."""
        # TODO: Rename this to __deepcopy__()?
        result = MPCTensor([])
        result._tensor = self._tensor.clone()
        result.ptype = self.ptype
        return result

    def shallow_copy(self):
        """Create a shallow copy of the input tensor."""
        # TODO: Rename this to __copy__()?
        result = MPCTensor([])
        result._tensor = self._tensor
        result.ptype = self.ptype
        return result

    def copy_(self, other):
        """Copies value of other MPCTensor into this MPCTensor."""
        assert isinstance(other, MPCTensor), "other must be MPCTensor"
        self._tensor.copy_(other._tensor)
        self.ptype = other.ptype

    def to(self, *args, **kwargs):
        r"""
        Depending on the input arguments,
        converts underlying share to the given ptype or
        performs `torch.to` on the underlying torch tensor

        To convert underlying share to the given ptype, call `to` as:
            to(ptype, **kwargs)

        It will call MPCTensor.to_ptype with the arguments provided above.

        Otherwise, `to` performs `torch.to` on the underlying
        torch tensor. See
        https://pytorch.org/docs/stable/tensors.html?highlight=#torch.Tensor.to
        for a reference of the parameters that can be passed in.

        Note that the ptype conversion returns a new MPCTensor, whereas the
        `torch.to` path replaces this tensor's share and returns `self`.

        Args:
            ptype: Ptype.arithmetic or Ptype.binary.
        """
        if "ptype" in kwargs:
            return self._to_ptype(**kwargs)
        elif args and isinstance(args[0], Ptype):
            ptype = args[0]
            return self._to_ptype(ptype, **kwargs)
        else:
            share = self.share.to(*args, **kwargs)
            # shares that end up on a GPU are wrapped in CUDALongTensor
            if share.is_cuda:
                share = CUDALongTensor(share)
            self.share = share
            return self

    def _to_ptype(self, ptype, **kwargs):
        r"""
        Convert MPCTensor's underlying share to the corresponding ptype
        (ArithmeticSharedTensor, BinarySharedTensor)

        Args:
            ptype (Ptype.arithmetic or Ptype.binary): The ptype to convert
                the shares to.
            precision (int, optional): Precision of the fixed point encoder when
                converting a binary share to an arithmetic share. It will be ignored
                if the ptype doesn't match.
            bits (int, optional): If specified, will only preserve the bottom `bits` bits
                of a binary tensor when converting from a binary share to an arithmetic share.
                It will be ignored if the ptype doesn't match.

        Returns:
            A new MPCTensor; a plain clone when the ptype already matches.
        """
        retval = self.clone()
        if retval.ptype == ptype:
            return retval
        retval._tensor = convert(self._tensor, ptype, **kwargs)
        retval.ptype = ptype
        return retval

    def arithmetic(self):
        """Converts self._tensor to arithmetic secret sharing"""
        return self.to(Ptype.arithmetic)

    def binary(self):
        """Converts self._tensor to binary secret sharing"""
        return self.to(Ptype.binary)

    @property
    def device(self):
        """Return the `torch.device` of the underlying share"""
        return self.share.device

    @property
    def is_cuda(self):
        """Return True if the underlying share is stored on GPU, False otherwise"""
        return self.share.is_cuda

    def cuda(self, *args, **kwargs):
        """Call `torch.Tensor.cuda` on the underlying share"""
        self.share = CUDALongTensor(self.share.cuda(*args, **kwargs))
        return self

    def cpu(self):
        """Call `torch.Tensor.cpu` on the underlying share"""
        self.share = self.share.cpu()
        return self

    def get_plain_text(self, dst=None):
        """Decrypts the tensor."""
        return self._tensor.get_plain_text(dst=dst)

    def reveal(self, dst=None):
        """Decrypts the tensor without any downscaling."""
        return self._tensor.reveal(dst=dst)

    def __bool__(self):
        """Override bool operator since encrypted tensors cannot evaluate"""
        raise RuntimeError("Cannot evaluate MPCTensors to boolean values")

    def __nonzero__(self):
        """__bool__ for backwards compatibility with Python 2"""
        raise RuntimeError("Cannot evaluate MPCTensors to boolean values")

    def __repr__(self):
        """Returns a representation of the tensor useful for debugging."""
        from crypten.debug import debug_mode

        share = self.share
        # the plain text is only decrypted and shown when debug mode is on
        plain_text = self._tensor.get_plain_text() if debug_mode() else "HIDDEN"
        ptype = self.ptype
        return (
            f"MPCTensor(\n\t_tensor={share}\n"
            f"\tplain_text={plain_text}\n\tptype={ptype}\n)"
        )

    def __setitem__(self, index, value):
        """Set tensor values by index"""
        if not isinstance(value, MPCTensor):
            value = MPCTensor(value, ptype=self.ptype, device=self.device)
        self._tensor.__setitem__(index, value._tensor)

    @property
    def share(self):
        """Returns underlying share"""
        return self._tensor.share

    @share.setter
    def share(self, value):
        """Sets share to value"""
        self._tensor.share = value

    @property
    def encoder(self):
        """Returns underlying encoder"""
        return self._tensor.encoder

    @encoder.setter
    def encoder(self, value):
        """Sets encoder to value"""
        self._tensor.encoder = value

    @staticmethod
    def __cat_stack_helper(op, tensors, *args, **kwargs):
        """Shared implementation of `cat` and `stack`.

        Args:
            op: either "cat" or "stack".
            tensors: non-empty list of tensors to combine.
            ptype (keyword, optional): sharing type of the result. Defaults to
                the ptype of the first MPCTensor in `tensors`, or arithmetic
                when the list contains no MPCTensor.

        Remaining positional / keyword arguments are forwarded to the
        underlying `torch_cat` / `torch_stack` call.
        """
        assert op in ["cat", "stack"], "Unsupported op for helper function"
        assert isinstance(tensors, list), "%s input must be a list" % op
        assert len(tensors) > 0, "expected a non-empty list of MPCTensors"

        _ptype = kwargs.pop("ptype", None)
        # Populate ptype field
        if _ptype is None:
            for tensor in tensors:
                if isinstance(tensor, MPCTensor):
                    _ptype = tensor.ptype
                    break
        if _ptype is None:
            _ptype = Ptype.arithmetic

        # Work on a copy so the ptype conversion below does not overwrite
        # entries of the caller's list.
        tensors = list(tensors)

        # Make all inputs MPCTensors of given ptype
        for i, tensor in enumerate(tensors):
            if tensor.ptype != _ptype:
                tensors[i] = tensor.to(_ptype)

        # Operate on all input tensors
        result = tensors[0].clone()
        funcs = {"cat": torch_cat, "stack": torch_stack}
        result.share = funcs[op]([tensor.share for tensor in tensors], *args, **kwargs)
        return result

    @staticmethod
    def cat(tensors, *args, **kwargs):
        """Perform matrix concatenation"""
        return MPCTensor.__cat_stack_helper("cat", tensors, *args, **kwargs)

    @staticmethod
    def stack(tensors, *args, **kwargs):
        """Perform tensor stacking"""
        return MPCTensor.__cat_stack_helper("stack", tensors, *args, **kwargs)

    @staticmethod
    def rand(*sizes, device=None):
        """
        Returns a tensor with elements uniformly sampled in [0, 1). The uniform
        random samples are generated by generating random bits using fixed-point
        encoding and converting the result to an ArithmeticSharedTensor.
        """
        # NOTE(review): `device` is accepted but not used by this method;
        # confirm whether it should be forwarded to BinarySharedTensor.rand.
        rand = MPCTensor([])
        encoder = FixedPointEncoder()
        rand._tensor = BinarySharedTensor.rand(*sizes, bits=encoder._precision_bits)
        rand._tensor.encoder = encoder
        rand.ptype = Ptype.binary
        return rand.to(Ptype.arithmetic, bits=encoder._precision_bits)

    @staticmethod
    def randn(*sizes, device=None):
        """
        Returns a tensor with normally distributed elements. Samples are
        generated using the Box-Muller transform with optimizations for
        numerical precision and MPC efficiency.
        """
        # `rand` currently ignores `device`; it is forwarded for consistency.
        u = MPCTensor.rand(*sizes, device=device).flatten()
        # the transform consumes uniform samples in pairs, so pad an odd count
        odd_numel = u.numel() % 2 == 1
        if odd_numel:
            u = MPCTensor.cat([u, MPCTensor.rand((1,), device=device)])

        n = u.numel() // 2
        u1 = u[:n]
        u2 = u[n:]

        # Radius = sqrt(- 2 * log(u1))
        r2 = -2 * u1.log(input_in_01=True)
        r = r2.sqrt()

        # Theta = cos(2 * pi * u2) or sin(2 * pi * u2)
        cos, sin = u2.sub(0.5).mul(6.28318531).cossin()

        # Generating 2 independent normal random variables using
        x = r.mul(sin)
        y = r.mul(cos)
        z = MPCTensor.cat([x, y])

        # drop the extra sample produced by the padding above
        if odd_numel:
            z = z[1:]

        return z.view(*sizes)

    def bernoulli(self):
        """Returns a tensor with elements in {0, 1}. The i-th element of the
        output will be 1 with probability according to the i-th value of the
        input tensor."""
        return self > MPCTensor.rand(self.size(), device=self.device)

    # TODO: It seems we can remove all Dropout implementations below?
    def dropout(self, p=0.5, training=True, inplace=False):
        r"""
        Randomly zeroes some of the elements of the input tensor with
        probability :attr:`p`.

        Args:
            p: probability of a channel to be zeroed. Default: 0.5
            training: apply dropout if is ``True``. Default: ``True``
            inplace: If set to ``True``, will do this operation in-place.
                Default: ``False``
        """
        assert p >= 0.0 and p <= 1.0, "dropout probability has to be between 0 and 1"
        if not training:
            if inplace:
                return self
            else:
                return self.clone()
        rand_tensor = MPCTensor.rand(self.size(), device=self.device)
        dropout_tensor = rand_tensor > p
        # NOTE: p == 1 passes the assertion above but makes the divisor zero.
        if inplace:
            result_tensor = self.mul_(dropout_tensor).div_(1 - p)
        else:
            result_tensor = self.mul(dropout_tensor).div_(1 - p)
        return result_tensor

    def dropout2d(self, p=0.5, training=True, inplace=False):
        r"""
        Randomly zero out entire channels (a channel is a 2D feature map,
        e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
        batched input is a 2D tensor :math:`\text{input}[i, j]`) of the input tensor).
        Each channel will be zeroed out independently on every forward call with
        probability :attr:`p` using samples from a Bernoulli distribution.

        Args:
            p: probability of a channel to be zeroed. Default: 0.5
            training: apply dropout if is ``True``. Default: ``True``
            inplace: If set to ``True``, will do this operation in-place.
                Default: ``False``
        """
        assert p >= 0.0 and p <= 1.0, "dropout probability has to be between 0 and 1"
        return self._feature_dropout(p, training, inplace)

    def dropout3d(self, p=0.5, training=True, inplace=False):
        r"""
        Randomly zero out entire channels (a channel is a 3D feature map,
        e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
        batched input is a 3D tensor :math:`\text{input}[i, j]`) of the input tensor).
        Each channel will be zeroed out independently on every forward call with
        probability :attr:`p` using samples from a Bernoulli distribution.

        Args:
            p: probability of a channel to be zeroed. Default: 0.5
            training: apply dropout if is ``True``. Default: ``True``
            inplace: If set to ``True``, will do this operation in-place.
                Default: ``False``
        """
        # This is 100% the same code as dropout2d. We duplicate this code so that
        # stack traces are not confusing.
        assert p >= 0.0 and p <= 1.0, "dropout probability has to be between 0 and 1"
        return self._feature_dropout(p, training, inplace)

    def _feature_dropout(self, p=0.5, training=True, inplace=False):
        r"""Randomly zeros out entire channels in the input tensor with
        probability :attr:`p`. (a channel is a nD feature map, e.g., the
        :math:`j`-th channel of the :math:`i`-th sample in the batched input
        is a nD tensor :math:`\text{input}[i, j]`)."""
        assert self.dim() >= 2, "feature dropout requires dimension to be at least 2"
        assert p >= 0.0 and p <= 1.0, "dropout probability has to be between 0 and 1"
        if not training:
            if inplace:
                return self
            else:
                return self.clone()
        # take first 2 dimensions
        feature_dropout_size = self.size()[0:2]
        # create dropout tensor over the first two dimensions
        rand_tensor = MPCTensor.rand(feature_dropout_size, device=self.device)
        feature_dropout_tensor = rand_tensor > p
        # Broadcast to remaining dimensions
        for i in range(2, self.dim()):
            feature_dropout_tensor = feature_dropout_tensor.unsqueeze(i)
        feature_dropout_tensor.share, self.share = torch.broadcast_tensors(
            feature_dropout_tensor.share, self.share
        )
        # NOTE: p == 1 passes the assertion above but makes the divisor zero.
        if inplace:
            result_tensor = self.mul_(feature_dropout_tensor).div_(1 - p)
        else:
            result_tensor = self.mul(feature_dropout_tensor).div_(1 - p)
        return result_tensor

    # Comparators
    @mode(Ptype.binary)
    def _ltz(self, _scale=True):
        """Returns 1 for elements that are < 0 and 0 otherwise"""
        # shifting right by (bit width - 1) leaves only the sign bit
        shift = torch.iinfo(torch.long).bits - 1
        result = (self >> shift).to(Ptype.arithmetic, bits=1)
        if _scale:
            return result * result.encoder._scale
        else:
            result.encoder._scale = 1
            return result

    @mode(Ptype.arithmetic)
    def ge(self, y, _scale=True):
        """Returns self >= y"""
        return 1 - self.lt(y, _scale=_scale)

    @mode(Ptype.arithmetic)
    def gt(self, y, _scale=True):
        """Returns self > y"""
        return (-self + y)._ltz(_scale=_scale)

    @mode(Ptype.arithmetic)
    def le(self, y, _scale=True):
        """Returns self <= y"""
        return 1 - self.gt(y, _scale=_scale)

    @mode(Ptype.arithmetic)
    def lt(self, y, _scale=True):
        """Returns self < y"""
        return (self - y)._ltz(_scale=_scale)

    @mode(Ptype.arithmetic)
    def eq(self, y, _scale=True):
        """Returns self == y"""
        # with exactly two parties a dedicated equality protocol is used
        if comm.get().get_world_size() == 2:
            return (self - y)._eqz_2PC(_scale=_scale)

        return 1 - self.ne(y, _scale=_scale)

    @mode(Ptype.arithmetic)
    def ne(self, y, _scale=True):
        """Returns self != y"""
        if comm.get().get_world_size() == 2:
            return 1 - self.eq(y, _scale=_scale)

        # x != y  <=>  (x - y < 0) or (y - x < 0); both tests are evaluated in
        # a single stacked _ltz call and then summed.
        difference = self - y
        difference.share = torch_stack([difference.share, -(difference.share)])
        return difference._ltz(_scale=_scale).sum(0)

    @mode(Ptype.arithmetic)
    def _eqz_2PC(self, _scale=True):
        """Returns self == 0"""
        # Create BinarySharedTensors from shares
        x0 = MPCTensor(self.share, src=0, ptype=Ptype.binary)
        x1 = MPCTensor(-self.share, src=1, ptype=Ptype.binary)

        # Perform equality testing using binary shares
        x0._tensor = x0._tensor.eq(x1._tensor)
        x0.encoder = x0.encoder if _scale else self.encoder

        # Convert to Arithmetic sharing
        result = x0.to(Ptype.arithmetic, bits=1)

        if not _scale:
            result.encoder._scale = 1

        return result

    @mode(Ptype.arithmetic)
    def sign(self, _scale=True):
        """Computes the sign value of a tensor (0 is considered positive)"""
        return 1 - 2 * self._ltz(_scale=_scale)

    @mode(Ptype.arithmetic)
    def abs(self):
        """Computes the absolute value of a tensor"""
        return self * self.sign(_scale=False)

    @mode(Ptype.arithmetic)
    def relu(self):
        """Compute a Rectified Linear function on the input tensor."""
        return self * self.ge(0, _scale=False)

    @mode(Ptype.arithmetic)
    def weighted_index(self, dim=None):
        """
        Returns a tensor with entries that are one-hot along dimension `dim`.
        These one-hot entries are set at random with weights given by the input
        `self`.

        Examples::

            >>> encrypted_tensor = MPCTensor(torch.tensor([1., 6.]))
            >>> index = encrypted_tensor.weighted_index().get_plain_text()
            # With 1 / 7 probability
            torch.tensor([1., 0.])

            # With 6 / 7 probability
            torch.tensor([0., 1.])
        """
        if dim is None:
            return self.flatten().weighted_index(dim=0).view(self.size())

        x = self.cumsum(dim)
        # the last cumulative-sum entry along `dim` is the total weight
        max_weight = x.index_select(
            dim, torch.tensor(x.size(dim) - 1, device=self.device)
        )
        r = MPCTensor.rand(max_weight.size(), device=self.device) * max_weight

        gt = x.gt(r, _scale=False)
        # keep only the first position where the cumulative sum exceeds r
        shifted = gt.roll(1, dims=dim)
        shifted.share.index_fill_(dim, torch.tensor(0, device=self.device), 0)

        return gt - shifted

    @mode(Ptype.arithmetic)
    def weighted_sample(self, dim=None):
        """
        Samples a single value across dimension `dim` with weights corresponding
        to the values in `self`

        Returns the sample and the one-hot index of the sample.

        Examples::

            >>> encrypted_tensor = MPCTensor(torch.tensor([1., 6.]))
            >>> index = encrypted_tensor.weighted_sample().get_plain_text()
            # With 1 / 7 probability
            (torch.tensor([1., 0.]), torch.tensor([1., 0.]))

            # With 6 / 7 probability
            (torch.tensor([0., 6.]), torch.tensor([0., 1.]))
        """
        indices = self.weighted_index(dim)
        sample = self.mul(indices).sum(dim)
        return sample, indices

    # max / min-related functions
    @mode(Ptype.arithmetic)
    def argmax(self, dim=None, keepdim=False, one_hot=True):
        """Returns the indices of the maximum value of all elements in the
        `input` tensor.
        """
        # TODO: Make dim an arg.
        if self.dim() == 0:
            # a 0-dim tensor has a single element: one-hot is 1, index is 0
            result = (
                MPCTensor(torch.ones((), device=self.device))
                if one_hot
                else MPCTensor(torch.zeros((), device=self.device))
            )
            return result

        result = _argmax_helper(
            self, dim, one_hot, config.max_method, _return_max=False
        )

        if not one_hot:
            result = _one_hot_to_index(result, dim, keepdim, self.device)
        return result

    @mode(Ptype.arithmetic)
    def argmin(self, dim=None, keepdim=False, one_hot=True):
        """Returns the indices of the minimum value of all elements in the
        `input` tensor.
        """
        # TODO: Make dim an arg.
        return (-self).argmax(dim=dim, keepdim=keepdim, one_hot=one_hot)

    @mode(Ptype.arithmetic)
    def max(self, dim=None, keepdim=False, one_hot=True):
        """Returns the maximum value of all elements in the input tensor.

        With `dim` given, returns a `(max, argmax)` pair instead; the argmax
        is one-hot when `one_hot` is True and an index tensor otherwise.
        """
        # TODO: Make dim an arg.
        method = config.max_method
        if dim is None:
            if method in ["log_reduction", "double_log_reduction"]:
                # max_result can be obtained directly
                max_result = _max_helper_all_tree_reductions(self, method=method)
            else:
                # max_result needs to be obtained through argmax
                with ConfigManager("max_method", method):
                    argmax_result = self.argmax(one_hot=True)
                max_result = self.mul(argmax_result).sum()
            return max_result
        else:
            argmax_result, max_result = _argmax_helper(
                self, dim=dim, one_hot=True, method=method, _return_max=True
            )
            if max_result is None:
                max_result = (self * argmax_result).sum(dim=dim, keepdim=keepdim)
            if keepdim:
                max_result = (
                    max_result.unsqueeze(dim)
                    if max_result.dim() < self.dim()
                    else max_result
                )
            if one_hot:
                return max_result, argmax_result
            else:
                return (
                    max_result,
                    _one_hot_to_index(argmax_result, dim, keepdim, self.device),
                )

    @mode(Ptype.arithmetic)
    def min(self, dim=None, keepdim=False, one_hot=True):
        """Returns the minimum value of all elements in the input tensor."""
        # TODO: Make dim an arg.
        result = (-self).max(dim=dim, keepdim=keepdim, one_hot=one_hot)
        if dim is None:
            return -result
        else:
            return -result[0], result[1]

    @mode(Ptype.arithmetic)
    def max_pool2d(self, kernel_size, padding=None, stride=None, return_indices=False):
        """Applies a 2D max pooling over an input signal composed of several
        input planes.
        """
        max_input = self.shallow_copy()
        max_input.share, output_size = pool_reshape(
            self.share,
            kernel_size,
            padding=padding,
            stride=stride,
            # padding with extremely negative values to avoid choosing pads
            # -2 ** 33 is acceptable since it is lower than the supported range
            # which is -2 ** 32 because multiplication can otherwise fail.
            pad_value=(-(2**33)),
        )
        max_vals, argmax_vals = max_input.max(dim=-1, one_hot=True)
        max_vals = max_vals.view(output_size)
        if return_indices:
            if isinstance(kernel_size, int):
                kernel_size = (kernel_size, kernel_size)
            argmax_vals = argmax_vals.view(output_size + kernel_size)
            return max_vals, argmax_vals
        return max_vals

    @mode(Ptype.arithmetic)
    def _max_pool2d_backward(
        self, indices, kernel_size, padding=None, stride=None, output_size=None
    ):
        """Implements the backwards for a `max_pool2d` call.

        Args:
            indices: 6-dimensional one-hot argmax tensor of the forward call.
            kernel_size: int or 2-tuple.
            padding: int, 2-tuple, or None (treated as 0).
            stride: int, 2-tuple, or None (defaults to `kernel_size`).
            output_size: size of the result; a minimal size is derived from
                the stride and padding when omitted.
        """
        # Setup padding
        if padding is None:
            padding = 0
        if isinstance(padding, int):
            padding = padding, padding
        assert isinstance(padding, tuple), "padding must be a int, tuple, or None"
        p0, p1 = padding

        # Setup stride
        if stride is None:
            stride = kernel_size
        if isinstance(stride, int):
            stride = stride, stride
        assert isinstance(stride, tuple), "stride must be a int, tuple, or None"
        s0, s1 = stride

        # Setup kernel_size
        if isinstance(kernel_size, int):
            kernel_size = kernel_size, kernel_size
        assert isinstance(kernel_size, tuple), "kernel_size must be a int or tuple"
        k0, k1 = kernel_size

        assert self.dim() == 4, "Input to _max_pool2d_backward must have 4 dimensions"
        assert (
            indices.dim() == 6
        ), "Indices input for _max_pool2d_backward must have 6 dimensions"

        # Computes one-hot gradient blocks from each output variable that
        # has non-zero value corresponding to the argmax of the corresponding
        # block of the max_pool2d input.
        kernels = self.view(self.size() + (1, 1)) * indices

        # Use minimal size if output_size is not specified.
        if output_size is None:
            output_size = (
                self.size(0),
                self.size(1),
                s0 * self.size(2) - 2 * p0,
                s1 * self.size(3) - 2 * p1,
            )

        # Sum the one-hot gradient blocks at corresponding index locations.
        # NOTE(review): the pad list is [p0, p0, p1, p1] while the crop below
        # removes p0 from dim 2 and p1 from dim 3 -- confirm the ordering for
        # asymmetric padding (p0 != p1).
        result = MPCTensor(torch.zeros(output_size)).pad([p0, p0, p1, p1])
        for i in range(self.size(2)):
            for j in range(self.size(3)):
                left_ind = s0 * i
                top_ind = s1 * j

                result[
                    :, :, left_ind : left_ind + k0, top_ind : top_ind + k1
                ] += kernels[:, :, i, j]

        # strip the padding again
        result = result[:, :, p0 : result.size(2) - p0, p1 : result.size(3) - p1]
        return result

    def adaptive_avg_pool2d(self, output_size):
        r"""
        Applies a 2D adaptive average pooling over an input signal composed of
        several input planes.

        See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape.

        Args:
            output_size: the target output size (single integer or
                double-integer tuple)
        """
        resized_input, args, kwargs = adaptive_pool2d_helper(
            self, output_size, reduction="mean"
        )
        return resized_input.avg_pool2d(*args, **kwargs)

    def adaptive_max_pool2d(self, output_size, return_indices=False):
        r"""Applies a 2D adaptive max pooling over an input signal composed of
        several input planes.

        See :class:`~torch.nn.AdaptiveMaxPool2d` for details and output shape.

        Args:
            output_size: the target output size (single integer or
                double-integer tuple)
            return_indices: whether to return pooling indices. Default: ``False``
        """
        resized_input, args, kwargs = adaptive_pool2d_helper(
            self, output_size, reduction="max"
        )
        return resized_input.max_pool2d(
            *args, **kwargs, return_indices=return_indices
        )

    def where(self, condition, y):
        """Selects elements from self or y based on condition

        Args:
            condition (torch.bool or MPCTensor): when True yield self,
                otherwise yield y
            y (torch.tensor or MPCTensor): values selected at indices
                where condition is False.

        Returns: MPCTensor or torch.tensor
        """
        if is_tensor(condition):
            condition = condition.float()
            y_masked = y * (1 - condition)
        else:
            # encrypted tensor must be first operand
            y_masked = (1 - condition) * y

        return self * condition + y_masked

    @mode(Ptype.arithmetic)
    def pad(self, pad, mode="constant", value=0):
        """Pads the tensor by delegating to the underlying shared tensor's `pad`.

        An MPCTensor `value` is unwrapped to its underlying shared tensor first.
        """
        result = self.shallow_copy()
        if isinstance(value, MPCTensor):
            result._tensor = self._tensor.pad(pad, mode=mode, value=value._tensor)
        else:
            result._tensor = self._tensor.pad(pad, mode=mode, value=value)
        return result

    @mode(Ptype.arithmetic)
    def polynomial(self, coeffs, func="mul"):
        """Computes a polynomial function on a tensor with given coefficients,
        `coeffs`, that can be a list of values or a 1-D tensor.

        Coefficients should be ordered from the order 1 (linear) term first,
        ending with the highest order term. (Constant is not included).
        """
        # Coefficient input type-checking
        if isinstance(coeffs, list):
            coeffs = torch.tensor(coeffs, device=self.device)
        assert is_tensor(coeffs) or crypten.is_encrypted_tensor(
            coeffs
        ), "Polynomial coefficients must be a list or tensor"
        assert coeffs.dim() == 1, "Polynomial coefficients must be a 1-D tensor"

        # Handle linear case
        if coeffs.size(0) == 1:
            return self.mul(coeffs)

        # Compute terms of polynomial using exponentially growing tree
        terms = crypten.stack([self, self.square()])
        while terms.size(0) < coeffs.size(0):
            highest_term = terms.index_select(
                0, torch.tensor(terms.size(0) - 1, device=self.device)
            )
            new_terms = getattr(terms, func)(highest_term)
            terms = crypten.cat([terms, new_terms])

        # Resize the coefficients for broadcast
        terms = terms[: coeffs.size(0)]
        for _ in range(terms.dim() - 1):
            coeffs = coeffs.unsqueeze(1)

        # Multiply terms by coefficients and sum
        return terms.mul(coeffs).sum(0)

    def div(self, y):
        r"""Divides each element of :attr:`self` with the scalar :attr:`y` or
        each element of the tensor :attr:`y` and returns a new resulting tensor.

        For `y` a scalar:

        .. math::
            \text{out}_i = \frac{\text{self}_i}{\text{y}}

        For `y` a tensor:

        .. math::
            \text{out}_i = \frac{\text{self}_i}{\text{y}_i}

        Note for :attr:`y` a tensor, the shapes of :attr:`self` and :attr:`y` must be
        `broadcastable`_.

        .. _broadcastable:
            https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics"""  # noqa: B950
        result = self.clone()
        # expand the result's share to the broadcast shape before the in-place
        # division
        if isinstance(y, CrypTensor):
            result.share = torch.broadcast_tensors(result.share, y.share)[0].clone()
        elif is_tensor(y):
            result.share = torch.broadcast_tensors(result.share, y)[0].clone()
        return result.div_(y)

    def div_(self, y):
        """In-place version of :meth:`div`"""
        # division by an encrypted tensor is multiplication by its reciprocal
        if isinstance(y, MPCTensor):
            return self.mul_(y.reciprocal())
        self._tensor.div_(y)
        return self

    def pow(self, p, **kwargs):
        """
        Computes an element-wise exponent `p` of a tensor, where `p` is an
        integer.

        Raises:
            TypeError: if `p` is not an integer (or an integer-valued float).
        """
        # is_integer() is False for inf / nan, so those fall through to the
        # TypeError below instead of failing inside int().
        if isinstance(p, float) and p.is_integer():
            p = int(p)

        if not isinstance(p, int):
            raise TypeError(
                "pow must take an integer exponent. For non-integer powers, use"
                " pos_pow with positive-valued base."
            )
        if p < -1:
            return self.reciprocal().pow(-p)
        elif p == -1:
            return self.reciprocal()
        elif p == 0:
            # Note: This returns 0 ** 0 -> 1 when inputs have zeros.
            # This is consistent with PyTorch's pow function.
            return MPCTensor(torch.ones_like(self.share))
        elif p == 1:
            return self.clone()
        elif p == 2:
            return self.square()
        elif p % 2 == 0:
            return self.square().pow(p // 2)
        else:
            # odd p: x^p = (x^2)^((p - 1) / 2) * x
            return self.square().pow((p - 1) // 2).mul_(self)

    def pow_(self, p, **kwargs):
        """In-place version of :meth:`pow`"""
        result = self.pow(p)
        self.share.set_(result.share.data)
        return self

    def pos_pow(self, p):
        """
        Approximates self ** p by computing: :math:`x^p = exp(p * log(x))`

        Note that this requires that the base `self` contain only positive values
        since log can only be computed on positive numbers.

        Note that the value of `p` can be an integer, float, public tensor, or
        encrypted tensor.
        """
        # integer exponents use the exact integer power
        if isinstance(p, int) or (isinstance(p, float) and p.is_integer()):
            return self.pow(p)
        return self.log().mul_(p).exp()

    def norm(self, p="fro", dim=None, keepdim=False):
        """Computes the p-norm of the input tensor (or along a dimension).

        Raises:
            NotImplementedError: for the nuclear norm (`p == "nuc"`).
            ValueError: for any other non-numeric `p`.
        """
        if p == "fro":
            p = 2

        if isinstance(p, (int, float)):
            assert p >= 1, "p-norm requires p >= 1"
            if p == 1:
                if dim is None:
                    return self.abs().sum()
                return self.abs().sum(dim, keepdim=keepdim)
            elif p == 2:
                if dim is None:
                    return self.square().sum().sqrt()
                return self.square().sum(dim, keepdim=keepdim).sqrt()
            elif p == float("inf"):
                if dim is None:
                    return self.abs().max()
                return self.abs().max(dim=dim, keepdim=keepdim)[0]
            else:
                if dim is None:
                    return self.abs().pos_pow(p).sum().pos_pow(1 / p)
                return self.abs().pos_pow(p).sum(dim, keepdim=keepdim).pos_pow(1 / p)
        elif p == "nuc":
            raise NotImplementedError("Nuclear norm is not implemented")
        else:
            raise ValueError(f"Improper value p ({p})for p-norm")

    def index_add(self, dim, index, tensor):
        """Performs out-of-place index_add: Accumulate the elements of tensor into the
        self tensor by adding to the indices in the order given in index.
        """
        return self.clone().index_add_(dim, index, tensor)

    def index_add_(self, dim, index, tensor):
        """Performs in-place index_add: Accumulate the elements of tensor into the
        self tensor by adding to the indices in the order given in index.

        Raises:
            TypeError: if `tensor` is neither a number, a tensor, nor an MPCTensor.
        """
        assert index.dim() == 1, "index needs to be a vector"
        public = isinstance(tensor, (int, float)) or is_tensor(tensor)
        private = isinstance(tensor, MPCTensor)
        if public:
            self._tensor.index_add_(dim, index, tensor)
        elif private:
            self._tensor.index_add_(dim, index, tensor._tensor)
        else:
            raise TypeError("index_add second tensor of unsupported type")
        return self

    def scatter_add(self, dim, index, other):
        """Adds all values from the tensor other into self at the indices
        specified in the index tensor.
        """
        return self.clone().scatter_add_(dim, index, other)

    def scatter_add_(self, dim, index, other):
        """Adds all values from the tensor other into self at the indices
        specified in the index tensor.

        Raises:
            TypeError: if `other` is neither a number, a tensor, nor a CrypTensor.
        """
        public = isinstance(other, (int, float)) or is_tensor(other)
        private = isinstance(other, CrypTensor)
        if public:
            self._tensor.scatter_add_(dim, index, other)
        elif private:
            self._tensor.scatter_add_(dim, index, other._tensor)
        else:
            raise TypeError("scatter_add second tensor of unsupported type")
        return self

    def scatter_(self, dim, index, src):
        """Writes all values from the tensor `src` into `self` at the indices
        specified in the `index` tensor. For each value in `src`, its output index
        is specified by its index in `src` for `dimension != dim` and by the
        corresponding value in `index` for `dimension = dim`.
        """
        # a public tensor is encrypted first so that shares are scattered
        if is_tensor(src):
            src = MPCTensor(src)
        assert isinstance(src, MPCTensor), "Unrecognized scatter src type: %s" % type(
            src
        )
        self.share.scatter_(dim, index, src.share)
        return self

    def scatter(self, dim, index, src):
        """Out-of-place version of :meth:`MPCTensor.scatter_`"""
        result = self.clone()
        return result.scatter_(dim, index, src)

    def unbind(self, dim=0):
        """Removes dimension `dim` and returns a tuple of MPCTensors, one per
        slice of the underlying share along that dimension."""
        shares = self.share.unbind(dim=dim)
        # placeholders are built through the constructor, then given the slices
        results = tuple(
            MPCTensor(0, ptype=self.ptype, device=self.device) for _ in shares
        )
        for result, share in zip(results, shares):
            result.share = share
        return results

    def split(self, split_size, dim=0):
        """Splits the underlying share with `torch.Tensor.split` and returns a
        tuple of MPCTensors, one per chunk."""
        shares = self.share.split(split_size, dim=dim)
        # placeholders are built through the constructor, then given the chunks
        results = tuple(
            MPCTensor(0, ptype=self.ptype, device=self.device) for _ in shares
        )
        for result, share in zip(results, shares):
            result.share = share
        return results

    def set(self, enc_tensor):
        """
        Sets self encrypted to enc_tensor in place by setting
        shares of self to those of enc_tensor.

        Args:
            enc_tensor (MPCTensor): with encrypted shares.
        """
        if is_tensor(enc_tensor):
            enc_tensor = MPCTensor(enc_tensor)
        assert isinstance(enc_tensor, MPCTensor), "enc_tensor must be an MPCTensor"
        self.share.set_(enc_tensor.share)
        return self
class ArithmeticSharedTensor(object):
    """
    Encrypted tensor object that uses additive sharing to perform computations.

    Additive shares are computed by splitting each value of the input tensor
    into n separate random values that add to the input tensor, where n is
    the number of parties present in the protocol (world_size).

    Each instance holds two pieces of state: `share` (stored in `_tensor`),
    this party's additive share, and `encoder`, the fixed-point encoder that
    determines the scale of the encoded values.
    """

    # constructors:
    def __init__(
        self,
        tensor=None,
        size=None,
        broadcast_size=False,
        precision=None,
        src=0,
        device=None,
    ):
        """
        Creates the shared tensor from the input `tensor` provided by party `src`.

        The other parties can specify a `tensor` or `size` to determine the size
        of the shared tensor object to create. In this case, all parties must
        specify the same (tensor) size to prevent the party's shares from varying
        in size, which leads to undefined behavior.

        Alternatively, the parties can set `broadcast_size` to `True` to have the
        `src` party broadcast the correct size. The parties who do not know the
        tensor size beforehand can provide an empty tensor as input. This is
        guaranteed to produce correct behavior but requires an additional
        communication round.

        The parties can also set the `precision` and `device` for their share of
        the tensor. If `device` is unspecified, it is set to `tensor.device`.

        If `src` equals `SENTINEL`, the constructor returns immediately and
        leaves both `share` and `encoder` unset; the alternate constructors in
        this class use that to build an empty shell and fill it in afterwards.
        """
        # do nothing if source is sentinel:
        if src == SENTINEL:
            return

        # assertions on inputs:
        assert (
            isinstance(src, int) and src >= 0 and src < comm.get().get_world_size()
        ), "specified source party does not exist"
        if self.rank == src:
            assert tensor is not None, "source must provide a data tensor"
            if hasattr(tensor, "src"):
                assert (
                    tensor.src == src
                ), "source of data tensor must match source of encryption"
        if not broadcast_size:
            assert (
                tensor is not None or size is not None
            ), "must specify tensor or size, or set broadcast_size"

        # if device is unspecified, try and get it from tensor:
        if device is None and tensor is not None and hasattr(tensor, "device"):
            device = tensor.device

        # encode the input tensor:
        self.encoder = FixedPointEncoder(precision_bits=precision)
        if tensor is not None:
            # integer tensors are converted to float before encoding unless
            # the caller explicitly asked for precision 0:
            if is_int_tensor(tensor) and precision != 0:
                tensor = tensor.float()
            tensor = self.encoder.encode(tensor)
            tensor = tensor.to(device=device)
            size = tensor.size()

        # if other parties do not know tensor's size, broadcast the size:
        if broadcast_size:
            size = comm.get().broadcast_obj(size, src)

        # generate pseudo-random zero sharing (PRZS) and add source's tensor:
        self.share = ArithmeticSharedTensor.PRZS(size, device=device).share
        if self.rank == src:
            self.share += tensor

    @staticmethod
    def new(*args, **kwargs):
        """
        Creates a new ArithmeticSharedTensor, passing all args and kwargs into
        the constructor.
        """
        return ArithmeticSharedTensor(*args, **kwargs)

    @property
    def device(self):
        """Return the `torch.device` of the underlying _tensor"""
        return self._tensor.device

    @property
    def is_cuda(self):
        """Return True if the underlying _tensor is stored on GPU, False otherwise"""
        return self._tensor.is_cuda

    def to(self, *args, **kwargs):
        """Call `torch.Tensor.to` on the underlying _tensor (in place); returns self"""
        self._tensor = self._tensor.to(*args, **kwargs)
        return self

    def cuda(self, *args, **kwargs):
        """Call `torch.Tensor.cuda` on the underlying _tensor and wrap the
        result in a `CUDALongTensor` (in place); returns self"""
        self._tensor = CUDALongTensor(self._tensor.cuda(*args, **kwargs))
        return self

    def cpu(self, *args, **kwargs):
        """Call `torch.Tensor.cpu` on the underlying _tensor (in place); returns self"""
        self._tensor = self._tensor.cpu(*args, **kwargs)
        return self

    @property
    def share(self):
        """Returns underlying _tensor"""
        return self._tensor

    @share.setter
    def share(self, value):
        """Sets _tensor to value"""
        self._tensor = value

    @staticmethod
    def from_shares(share, precision=None, device=None):
        """Generate an ArithmeticSharedTensor from a share from each party

        Args:
            share: this party's share; moved to `device` when one is given.
            precision: forwarded to `FixedPointEncoder(precision_bits=...)`.
            device: optional target device for the share.
        """
        result = ArithmeticSharedTensor(src=SENTINEL)
        share = share.to(device) if device is not None else share
        result.share = CUDALongTensor(share) if share.is_cuda else share
        result.encoder = FixedPointEncoder(precision_bits=precision)
        return result

    @staticmethod
    def PRZS(*size, device=None):
        """
        Generate a Pseudo-random Sharing of Zero (using arithmetic shares)

        This function does so by generating `n` numbers across `n` parties with
        each number being held by exactly 2 parties. One of these parties adds
        this number while the other subtracts this number.

        The returned tensor has its `share` set but no `encoder`.
        """
        from crypten import generators

        tensor = ArithmeticSharedTensor(src=SENTINEL)
        if device is None:
            device = torch.device("cpu")
        elif isinstance(device, str):
            device = torch.device(device)
        g0 = generators["prev"][device]
        g1 = generators["next"][device]
        current_share = generate_random_ring_element(
            *size, generator=g0, device=device
        )
        next_share = generate_random_ring_element(*size, generator=g1, device=device)
        tensor.share = current_share - next_share
        return tensor

    @staticmethod
    def PRSS(*size, device=None):
        """
        Generates a Pseudo-random Secret Share from a set of random arithmetic
        shares
        """
        share = generate_random_ring_element(*size, device=device)
        tensor = ArithmeticSharedTensor.from_shares(share=share)
        return tensor

    @property
    def rank(self):
        """Rank of the current party, as reported by the communicator"""
        return comm.get().get_rank()

    def shallow_copy(self):
        """Create a shallow copy (the new object references the same _tensor)"""
        result = ArithmeticSharedTensor(src=SENTINEL)
        result.encoder = self.encoder
        result._tensor = self._tensor
        return result

    def clone(self):
        """Create a copy whose _tensor is a clone of this one (encoder is shared)"""
        result = ArithmeticSharedTensor(src=SENTINEL)
        result.encoder = self.encoder
        result._tensor = self._tensor.clone()
        return result

    def copy_(self, other):
        """Copies other tensor into this tensor."""
        self.share.copy_(other.share)
        self.encoder = other.encoder

    def __repr__(self):
        return f"ArithmeticSharedTensor({self.share})"

    def __bool__(self):
        """Override bool operator since encrypted tensors cannot evaluate"""
        raise RuntimeError("Cannot evaluate ArithmeticSharedTensors to boolean values")

    def __nonzero__(self):
        """__bool__ for backwards compatibility with Python 2"""
        raise RuntimeError("Cannot evaluate ArithmeticSharedTensors to boolean values")

    def __setitem__(self, index, value):
        """Set tensor values by index

        Plain numbers and tensors are first secret-shared by constructing an
        ArithmeticSharedTensor from them.
        """
        if isinstance(value, (int, float)) or is_tensor(value):
            value = ArithmeticSharedTensor(value)
        assert isinstance(
            value, ArithmeticSharedTensor
        ), "Unsupported input type %s for __setitem__" % type(value)
        self.share.__setitem__(index, value.share)

    def pad(self, pad, mode="constant", value=0):
        """
        Pads the input tensor with values provided in `value`.

        Only `mode="constant"` is supported. A public `value` (int or float)
        is encoded and padded in by rank 0 only, while every other rank pads
        with 0. A private `value` must be a 0-dimensional
        ArithmeticSharedTensor; each party pads with its own share of it.

        Returns a new ArithmeticSharedTensor; raises TypeError for any other
        `value` type.
        """
        assert mode == "constant", (
            "Padding with mode %s is currently unsupported" % mode
        )

        result = self.shallow_copy()
        if isinstance(value, (int, float)):
            value = self.encoder.encode(value).item()
            if result.rank == 0:
                result.share = torch.nn.functional.pad(
                    result.share, pad, mode=mode, value=value
                )
            else:
                result.share = torch.nn.functional.pad(
                    result.share, pad, mode=mode, value=0
                )
        elif isinstance(value, ArithmeticSharedTensor):
            assert (
                value.dim() == 0
            ), "Private values used for padding must be 0-dimensional"
            value = value.share.item()
            result.share = torch.nn.functional.pad(
                result.share, pad, mode=mode, value=value
            )
        else:
            raise TypeError(
                "Cannot pad ArithmeticSharedTensor with a %s value" % type(value)
            )

        return result

    @staticmethod
    def stack(tensors, *args, **kwargs):
        """Perform tensor stacking

        Plain torch tensors in `tensors` are secret-shared first. The
        conversion is done on a private list so the caller's sequence is left
        untouched (and may be any indexable sequence, e.g. a tuple).
        """
        shared = []
        for tensor in tensors:
            candidate = ArithmeticSharedTensor(tensor) if is_tensor(tensor) else tensor
            assert isinstance(
                candidate, ArithmeticSharedTensor
            ), "Can't stack %s with ArithmeticSharedTensor" % type(tensor)
            shared.append(candidate)

        # the first input supplies the encoder of the result:
        result = shared[0].shallow_copy()
        result.share = torch_stack([item.share for item in shared], *args, **kwargs)
        return result

    @staticmethod
    def reveal_batch(tensor_or_list, dst=None):
        """Get (batched) plaintext without any downscaling

        Accepts a single ArithmeticSharedTensor or a list of them. With
        `dst=None` the shares are all-reduced; otherwise they are reduced to
        party `dst`.
        """
        if isinstance(tensor_or_list, ArithmeticSharedTensor):
            return tensor_or_list.reveal(dst=dst)

        assert isinstance(
            tensor_or_list, list
        ), f"Invalid input type into reveal {type(tensor_or_list)}"
        shares = [tensor.share for tensor in tensor_or_list]
        if dst is None:
            return comm.get().all_reduce(shares, batched=True)
        else:
            return comm.get().reduce(shares, dst, batched=True)

    def reveal(self, dst=None):
        """Decrypts the tensor without any downscaling."""
        # reduce a clone so this party's own share is not passed to the
        # communicator directly:
        tensor = self.share.clone()
        if dst is None:
            return comm.get().all_reduce(tensor)
        else:
            return comm.get().reduce(tensor, dst)

    def get_plain_text(self, dst=None):
        """Decrypts the tensor."""
        # Edge case where share becomes 0 sized (e.g. result of split)
        if self.nelement() < 1:
            return torch.empty(self.share.size())
        return self.encoder.decode(self.reveal(dst=dst))

    def encode_(self, new_encoder):
        """Rescales the input to a new encoding in-place"""
        if self.encoder.scale == new_encoder.scale:
            return self
        elif self.encoder.scale < new_encoder.scale:
            # scaling up is a local multiplication of the share:
            scale_factor = new_encoder.scale // self.encoder.scale
            self.share *= scale_factor
        else:
            # scaling down goes through div_ (integer division of a sharing):
            scale_factor = self.encoder.scale // new_encoder.scale
            self = self.div_(scale_factor)
        self.encoder = new_encoder
        return self

    def encode(self, new_encoder):
        """Rescales the input to a new encoding"""
        return self.clone().encode_(new_encoder)

    def encode_as_(self, other):
        """Rescales self to have the same encoding as other"""
        return self.encode_(other.encoder)

    def encode_as(self, other):
        """Out-of-place version of `encode_as_`"""
        return self.encode(other.encoder)

    def _arithmetic_function_(self, y, op, *args, **kwargs):
        """In-place variant of `_arithmetic_function`."""
        # `inplace` is passed positionally: supplying it as a keyword ahead of
        # `*args` makes any extra positional argument bind to `inplace` a
        # second time and raise TypeError.
        return self._arithmetic_function(y, op, True, *args, **kwargs)

    def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs):  # noqa:C901
        """Shared implementation of the binary arithmetic operations.

        Args:
            y: public operand (int, float or torch tensor) or private operand
                (ArithmeticSharedTensor).
            op (str): one of "add", "sub", "mul", "matmul", "conv1d",
                "conv2d", "conv_transpose1d", "conv_transpose2d".
            inplace (bool): operate on `self` instead of on a clone.
            *args, **kwargs: forwarded to the torch / protocol function for
                the non-additive operations.

        Returns:
            The result tensor (`self` when `inplace` is True).

        Raises:
            TypeError: if `y` is neither a public nor a private operand.
        """
        assert op in [
            "add",
            "sub",
            "mul",
            "matmul",
            "conv1d",
            "conv2d",
            "conv_transpose1d",
            "conv_transpose2d",
        ], f"Provided op `{op}` is not a supported arithmetic function"

        additive_func = op in ["add", "sub"]
        public = isinstance(y, (int, float)) or is_tensor(y)
        private = isinstance(y, ArithmeticSharedTensor)

        if inplace:
            result = self
            # only these ops have an in-place torch counterpart used below:
            if additive_func or (op == "mul" and public):
                op += "_"
        else:
            result = self.clone()

        if public:
            y = result.encoder.encode(y, device=self.device)

            if additive_func:  # ['add', 'sub']
                # a public value is added by rank 0 only; the other parties
                # just broadcast their share to the result shape:
                if result.rank == 0:
                    result.share = getattr(result.share, op)(y)
                else:
                    result.share = torch.broadcast_tensors(result.share, y)[0]
            elif op == "mul_":  # ['mul_']
                result.share = result.share.mul_(y)
            else:  # ['mul', 'matmul', 'convNd', 'conv_transposeNd']
                result.share = getattr(torch, op)(result.share, y, *args, **kwargs)
        elif private:
            if additive_func:  # ['add', 'sub', 'add_', 'sub_']
                # Re-encode if necessary:
                # NOTE: when `y` has the smaller scale it is re-encoded in
                # place, so the caller's operand is modified.
                if self.encoder.scale > y.encoder.scale:
                    y.encode_as_(result)
                elif self.encoder.scale < y.encoder.scale:
                    result.encode_as_(y)
                result.share = getattr(result.share, op)(y.share)
            else:  # ['mul', 'matmul', 'convNd', 'conv_transposeNd']
                # the MPC protocol is looked up by its configured name:
                protocol = globals()[cfg.mpc.protocol]
                result.share.set_(
                    getattr(protocol, op)(result, y, *args, **kwargs).share.data
                )
        else:
            raise TypeError("Cannot %s %s with %s" % (op, type(y), type(self)))

        # Scale by encoder scale if necessary
        if not additive_func:
            if public:  # scale by self.encoder.scale
                if self.encoder.scale > 1:
                    return result.div_(result.encoder.scale)
                else:
                    result.encoder = self.encoder
            else:  # scale by larger of self.encoder.scale and y.encoder.scale
                if self.encoder.scale > 1 and y.encoder.scale > 1:
                    return result.div_(result.encoder.scale)
                elif self.encoder.scale > 1:
                    result.encoder = self.encoder
                else:
                    result.encoder = y.encoder

        return result

    def add(self, y):
        """Perform element-wise addition"""
        return self._arithmetic_function(y, "add")

    def add_(self, y):
        """Perform element-wise addition (in place)"""
        return self._arithmetic_function_(y, "add")

    def sub(self, y):
        """Perform element-wise subtraction"""
        return self._arithmetic_function(y, "sub")

    def sub_(self, y):
        """Perform element-wise subtraction (in place)"""
        return self._arithmetic_function_(y, "sub")

    def mul(self, y):
        """Perform element-wise multiplication"""
        # a Python int multiplies the share directly, with no encoding and no
        # rescaling afterwards:
        if isinstance(y, int):
            result = self.clone()
            result.share = self.share * y
            return result
        return self._arithmetic_function(y, "mul")

    def mul_(self, y):
        """Perform element-wise multiplication (in place)"""
        # ints and integer tensors multiply the share directly, with no
        # encoding and no rescaling afterwards:
        if isinstance(y, int) or is_int_tensor(y):
            self.share *= y
            return self
        return self._arithmetic_function_(y, "mul")

    def div(self, y):
        """Divide by a given tensor"""
        result = self.clone()
        # broadcast the share to the output shape up front, since div_
        # operates in place:
        if isinstance(y, CrypTensor):
            result.share = torch.broadcast_tensors(result.share, y.share)[0].clone()
        elif is_tensor(y):
            result.share = torch.broadcast_tensors(result.share, y)[0].clone()
        return result.div_(y)

    def div_(self, y):
        """Divide two tensors element-wise

        Integer-valued divisors (ints, integer tensors, and floats / float
        tensors with no fractional part) use truncating integer division of
        the shares; any other float divisor is handled by multiplying with its
        reciprocal.

        Raises:
            ValueError: in validation mode, if the integer division result is
                off by `tolerance` or more.
        """
        # TODO: Add test coverage for this code path (next 4 lines)
        if isinstance(y, float) and int(y) == y:
            y = int(y)
        if is_float_tensor(y) and y.frac().eq(0).all():
            y = y.long()

        if isinstance(y, int) or is_int_tensor(y):
            validate = cfg.debug.validation_mode

            if validate:
                tolerance = 1.0
                tensor = self.get_plain_text()

            # Truncate protocol for dividing by public integers:
            if comm.get().get_world_size() > 2:
                protocol = globals()[cfg.mpc.protocol]
                protocol.truncate(self, y)
            else:
                self.share = self.share.div_(y, rounding_mode="trunc")

            # Validate
            if validate:
                if not torch.lt(
                    torch.abs(self.get_plain_text() * y - tensor), tolerance
                ).all():
                    raise ValueError("Final result of division is incorrect.")

            return self

        # Otherwise multiply by reciprocal
        if isinstance(y, float):
            y = torch.tensor([y], dtype=torch.float, device=self.device)

        assert is_float_tensor(y), "Unsupported type for div_: %s" % type(y)
        return self.mul_(y.reciprocal())

    def matmul(self, y):
        """Perform matrix multiplication using some tensor"""
        return self._arithmetic_function(y, "matmul")

    def conv1d(self, kernel, **kwargs):
        """Perform a 1D convolution using the given kernel"""
        return self._arithmetic_function(kernel, "conv1d", **kwargs)

    def conv2d(self, kernel, **kwargs):
        """Perform a 2D convolution using the given kernel"""
        return self._arithmetic_function(kernel, "conv2d", **kwargs)

    def conv_transpose1d(self, kernel, **kwargs):
        """Perform a 1D transpose convolution (deconvolution) using the given kernel"""
        return self._arithmetic_function(kernel, "conv_transpose1d", **kwargs)

    def conv_transpose2d(self, kernel, **kwargs):
        """Perform a 2D transpose convolution (deconvolution) using the given kernel"""
        return self._arithmetic_function(kernel, "conv_transpose2d", **kwargs)

    def index_add(self, dim, index, tensor):
        """Perform out-of-place index_add: Accumulate the elements of tensor into the
        self tensor by adding to the indices in the order given in index."""
        result = self.clone()
        return result.index_add_(dim, index, tensor)

    def index_add_(self, dim, index, tensor):
        """Perform in-place index_add: Accumulate the elements of tensor into the
        self tensor by adding to the indices in the order given in index."""
        public = isinstance(tensor, (int, float)) or is_tensor(tensor)
        private = isinstance(tensor, ArithmeticSharedTensor)
        if public:
            enc_tensor = self.encoder.encode(tensor)
            # a public value is added to rank 0's share only:
            if self.rank == 0:
                self._tensor.index_add_(dim, index, enc_tensor)
        elif private:
            self._tensor.index_add_(dim, index, tensor._tensor)
        else:
            raise TypeError("index_add second tensor of unsupported type")
        return self

    def scatter_add(self, dim, index, other):
        """Adds all values from the tensor other into self at the indices
        specified in the index tensor in a similar fashion as scatter_(). For
        each value in other, it is added to an index in self which is specified
        by its index in other for dimension != dim and by the corresponding
        value in index for dimension = dim.
        """
        return self.clone().scatter_add_(dim, index, other)

    def scatter_add_(self, dim, index, other):
        """Adds all values from the tensor other into self at the indices
        specified in the index tensor in a similar fashion as scatter_(). For
        each value in other, it is added to an index in self which is specified
        by its index in other for dimension != dim and by the corresponding
        value in index for dimension = dim.
        """
        public = isinstance(other, (int, float)) or is_tensor(other)
        private = isinstance(other, ArithmeticSharedTensor)
        if public:
            # a public value is added to rank 0's share only:
            if self.rank == 0:
                self.share.scatter_add_(dim, index, self.encoder.encode(other))
        elif private:
            self.share.scatter_add_(dim, index, other.share)
        else:
            raise TypeError("scatter_add second tensor of unsupported type")
        return self

    def avg_pool2d(self, kernel_size, stride=None, padding=0, ceil_mode=False):
        """Perform an average pooling on each 2D matrix of the given tensor

        Implemented as a sum pooling followed by a division by the number of
        elements in the pooling window.

        Args:
            kernel_size (int or tuple): pooling kernel size.

        Raises:
            NotImplementedError: if `ceil_mode` is set.
        """
        # TODO: Add check for whether ceil_mode would change size of output and allow ceil_mode when it wouldn't
        if ceil_mode:
            raise NotImplementedError(
                "CrypTen does not support `ceil_mode` for `avg_pool2d`"
            )

        z = self._sum_pool2d(
            kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode
        )
        if isinstance(kernel_size, (int, float)):
            pool_size = kernel_size**2
        else:
            pool_size = kernel_size[0] * kernel_size[1]
        return z / pool_size

    def _sum_pool2d(self, kernel_size, stride=None, padding=0, ceil_mode=False):
        """Perform a sum pooling on each 2D matrix of the given tensor"""
        result = self.shallow_copy()
        # avg_pool2d with a divisor of 1 yields the plain window sum:
        result.share = torch.nn.functional.avg_pool2d(
            self.share,
            kernel_size,
            stride=stride,
            padding=padding,
            ceil_mode=ceil_mode,
            divisor_override=1,
        )
        return result

    # negation and reciprocal:
    def neg_(self):
        """Negate the tensor's values (in place)"""
        self.share.neg_()
        return self

    def neg(self):
        """Negate the tensor's values"""
        return self.clone().neg_()

    def square_(self):
        """Square the tensor's values in place using the configured MPC protocol"""
        protocol = globals()[cfg.mpc.protocol]
        # divide by the encoder scale after squaring to rescale the product:
        self.share = protocol.square(self).div_(self.encoder.scale).share
        return self

    def square(self):
        """Out-of-place version of `square_`"""
        return self.clone().square_()

    def where(self, condition, y):
        """Selects elements from self or y based on condition

        Args:
            condition (torch.bool or ArithmeticSharedTensor): when True
                yield self, otherwise yield y.
            y (torch.tensor or ArithmeticSharedTensor): values selected at
                indices where condition is False.

        Returns: ArithmeticSharedTensor or torch.tensor
        """
        if is_tensor(condition):
            condition = condition.float()
            y_masked = y * (1 - condition)
        else:
            # encrypted tensor must be first operand
            y_masked = (1 - condition) * y

        return self * condition + y_masked

    def scatter_(self, dim, index, src):
        """Writes all values from the tensor `src` into `self` at the indices
        specified in the `index` tensor. For each value in `src`, its output index
        is specified by its index in `src` for `dimension != dim` and by the
        corresponding value in `index` for `dimension = dim`.
        """
        if is_tensor(src):
            src = ArithmeticSharedTensor(src)
        assert isinstance(
            src, ArithmeticSharedTensor
        ), "Unrecognized scatter src type: %s" % type(src)
        self.share.scatter_(dim, index, src.share)
        return self

    def scatter(self, dim, index, src):
        """Writes all values from the tensor `src` into `self` at the indices
        specified in the `index` tensor. For each value in `src`, its output index
        is specified by its index in `src` for `dimension != dim` and by the
        corresponding value in `index` for `dimension = dim`.
        """
        result = self.clone()
        return result.scatter_(dim, index, src)

    # overload operators:
    __add__ = add
    __iadd__ = add_
    __radd__ = __add__
    __sub__ = sub
    __isub__ = sub_
    __mul__ = mul
    __imul__ = mul_
    __rmul__ = __mul__
    __div__ = div
    __truediv__ = div
    __itruediv__ = div_
    __neg__ = neg

    def __rsub__(self, tensor):
        """Subtracts self from tensor."""
        return -self + tensor

    @property
    def data(self):
        """Returns the `.data` of the underlying _tensor"""
        return self._tensor.data

    @data.setter
    def data(self, value):
        """Points the underlying _tensor at `value` via `set_`"""
        self._tensor.set_(value)
class MPCTensor(CrypTensor): def __init__(self, tensor, ptype=Ptype.arithmetic, device=None, *args, **kwargs): """ Creates the shared tensor from the input `tensor` provided by party `src`. The `ptype` defines the type of sharing used (default: arithmetic). The other parties can specify a `tensor` or `size` to determine the size of the shared tensor object to create. In this case, all parties must specify the same (tensor) size to prevent the party's shares from varying in size, which leads to undefined behavior. Alternatively, the parties can set `broadcast_size` to `True` to have the `src` party broadcast the correct size. The parties who do not know the tensor size beforehand can provide an empty tensor as input. This is guaranteed to produce correct behavior but requires an additional communication round. The parties can also set the `precision` and `device` for their share of the tensor. If `device` is unspecified, it is set to `tensor.device`. """ if tensor is None: raise ValueError("Cannot initialize tensor with None.") # take required_grad from kwargs, input tensor, or set to False: default = tensor.requires_grad if torch.is_tensor(tensor) else False requires_grad = kwargs.pop("requires_grad", default) # call CrypTensor constructor: super().__init__(requires_grad=requires_grad) if device is None and hasattr(tensor, "device"): device = tensor.device # create the MPCTensor: tensor_type = ptype.to_tensor() if tensor is []: self._tensor = torch.tensor([], device=device) else: self._tensor = tensor_type(tensor=tensor, device=device, *args, **kwargs) self.ptype = ptype @staticmethod def new(*args, **kwargs): """ Creates a new MPCTensor, passing all args and kwargs into the constructor. 
""" return MPCTensor(*args, **kwargs) @staticmethod def from_shares(share, precision=None, ptype=Ptype.arithmetic): result = MPCTensor([]) from_shares = ptype.to_tensor().from_shares result._tensor = from_shares(share, precision=precision) result.ptype = ptype return result def clone(self): """Create a deep copy of the input tensor.""" # TODO: Rename this to __deepcopy__()? result = MPCTensor([]) result._tensor = self._tensor.clone() result.ptype = self.ptype return result def shallow_copy(self): """Create a shallow copy of the input tensor.""" # TODO: Rename this to __copy__()? result = MPCTensor([]) result._tensor = self._tensor result.ptype = self.ptype return result def copy_(self, other): """Copies value of other MPCTensor into this MPCTensor.""" assert isinstance(other, MPCTensor), "other must be MPCTensor" self._tensor.copy_(other._tensor) self.ptype = other.ptype def to(self, *args, **kwargs): r""" Depending on the input arguments, converts underlying share to the given ptype or performs `torch.to` on the underlying torch tensor To convert underlying share to the given ptype, call `to` as: to(ptype, **kwargs) It will call MPCTensor.to_ptype with the arguments provided above. Otherwise, `to` performs `torch.to` on the underlying torch tensor. See https://pytorch.org/docs/stable/tensors.html?highlight=#torch.Tensor.to for a reference of the parameters that can be passed in. Args: ptype: Ptype.arithmetic or Ptype.binary. """ if "ptype" in kwargs: return self._to_ptype(**kwargs) elif args and isinstance(args[0], Ptype): ptype = args[0] return self._to_ptype(ptype, **kwargs) else: share = self.share.to(*args, **kwargs) if share.is_cuda: share = CUDALongTensor(share) self.share = share return self def _to_ptype(self, ptype, **kwargs): r""" Convert MPCTensor's underlying share to the corresponding ptype (ArithmeticSharedTensor, BinarySharedTensor) Args: ptype (Ptype.arithmetic or Ptype.binary): The ptype to convert the shares to. 
precision (int, optional): Precision of the fixed point encoder when converting a binary share to an arithmetic share. It will be ignored if the ptype doesn't match. bits (int, optional): If specified, will only preserve the bottom `bits` bits of a binary tensor when converting from a binary share to an arithmetic share. It will be ignored if the ptype doesn't match. """ retval = self.clone() if retval.ptype == ptype: return retval retval._tensor = convert(self._tensor, ptype, **kwargs) retval.ptype = ptype return retval def arithmetic(self): """Converts self._tensor to arithmetic secret sharing""" return self.to(Ptype.arithmetic) def binary(self): """Converts self._tensor to binary secret sharing""" return self.to(Ptype.binary) @property def device(self): """Return the `torch.device` of the underlying share""" return self.share.device @property def is_cuda(self): """Return True if the underlying share is stored on GPU, False otherwise""" return self.share.is_cuda def cuda(self, *args, **kwargs): """Call `torch.Tensor.cuda` on the underlying share""" self.share = CUDALongTensor(self.share.cuda(*args, **kwargs)) return self def cpu(self): """Call `torch.Tensor.cpu` on the underlying share""" self.share = self.share.cpu() return self def get_plain_text(self, dst=None): """Decrypts the tensor.""" return self._tensor.get_plain_text(dst=dst) def reveal(self, dst=None): """Decrypts the tensor without any downscaling.""" return self._tensor.reveal(dst=dst) def __repr__(self): """Returns a representation of the tensor useful for debugging.""" from crypten.debug import debug_mode share = self.share plain_text = self._tensor.get_plain_text() if debug_mode( ) else "HIDDEN" ptype = self.ptype return (f"MPCTensor(\n\t_tensor={share}\n" f"\tplain_text={plain_text}\n\tptype={ptype}\n)") def __hash__(self): return hash(self.share) @property def share(self): """Returns underlying share""" return self._tensor.share @share.setter def share(self, value): """Sets share to value""" 
self._tensor.share = value @property def data(self): """Returns share data""" return self.share.data @data.setter def data(self, value): """Sets data to value""" self.share.data = value @property def encoder(self): """Returns underlying encoder""" return self._tensor.encoder @encoder.setter def encoder(self, value): """Sets encoder to value""" self._tensor.encoder = value @staticmethod def __cat_stack_helper(op, tensors, *args, **kwargs): assert op in ["cat", "stack"], "Unsupported op for helper function" assert isinstance(tensors, list), "%s input must be a list" % op assert len(tensors) > 0, "expected a non-empty list of MPCTensors" _ptype = kwargs.pop("ptype", None) # Populate ptype field if _ptype is None: for tensor in tensors: if isinstance(tensor, MPCTensor): _ptype = tensor.ptype break if _ptype is None: _ptype = Ptype.arithmetic # Make all inputs MPCTensors of given ptype for i, tensor in enumerate(tensors): if tensor.ptype != _ptype: tensors[i] = tensor.to(_ptype) # Operate on all input tensors result = tensors[0].clone() funcs = {"cat": torch_cat, "stack": torch_stack} result.share = funcs[op]([tensor.share for tensor in tensors], *args, **kwargs) return result @staticmethod def cat(tensors, *args, **kwargs): """Perform matrix concatenation""" return MPCTensor.__cat_stack_helper("cat", tensors, *args, **kwargs) @staticmethod def stack(tensors, *args, **kwargs): """Perform tensor stacking""" return MPCTensor.__cat_stack_helper("stack", tensors, *args, **kwargs) @staticmethod def rand(*sizes, device=None): """ Returns a tensor with elements uniformly sampled in [0, 1). The uniform random samples are generated by generating random bits using fixed-point encoding and converting the result to an ArithmeticSharedTensor. 
""" rand = MPCTensor([]) encoder = FixedPointEncoder() rand._tensor = BinarySharedTensor.rand(*sizes, bits=encoder._precision_bits, device=device) rand._tensor.encoder = encoder rand.ptype = Ptype.binary return rand.to(Ptype.arithmetic, bits=encoder._precision_bits) @staticmethod def randn(*sizes, device=None): """ Returns a tensor with normally distributed elements. Samples are generated using the Box-Muller transform with optimizations for numerical precision and MPC efficiency. """ u = MPCTensor.rand(*sizes, device=device).flatten() odd_numel = u.numel() % 2 == 1 if odd_numel: u = MPCTensor.cat([u, MPCTensor.rand((1, ), device=device)]) n = u.numel() // 2 u1 = u[:n] u2 = u[n:] # Radius = sqrt(- 2 * log(u1)) r2 = -2 * u1.log(input_in_01=True) r = r2.sqrt() # Theta = cos(2 * pi * u2) or sin(2 * pi * u2) cos, sin = u2.sub(0.5).mul(6.28318531).cossin() # Generating 2 independent normal random variables using x = r.mul(sin) y = r.mul(cos) z = MPCTensor.cat([x, y]) if odd_numel: z = z[1:] return z.view(*sizes) def bernoulli(self): """Returns a tensor with elements in {0, 1}. 
The i-th element of the output will be 1 with probability according to the i-th value of the input tensor.""" return self > MPCTensor.rand(self.size(), device=self.device) # Comparators @mode(Ptype.binary) def _ltz(self, _scale=True): """Returns 1 for elements that are < 0 and 0 otherwise""" shift = torch.iinfo(torch.long).bits - 1 result = (self >> shift).to(Ptype.arithmetic, bits=1) if _scale: return result * result.encoder._scale else: result.encoder._scale = 1 return result @mode(Ptype.arithmetic) def ge(self, y, _scale=True): """Returns self >= y""" return 1 - self.lt(y, _scale=_scale) @mode(Ptype.arithmetic) def gt(self, y, _scale=True): """Returns self > y""" return (-self + y)._ltz(_scale=_scale) @mode(Ptype.arithmetic) def le(self, y, _scale=True): """Returns self <= y""" return 1 - self.gt(y, _scale=_scale) @mode(Ptype.arithmetic) def lt(self, y, _scale=True): """Returns self < y""" return (self - y)._ltz(_scale=_scale) @mode(Ptype.arithmetic) def eq(self, y, _scale=True): """Returns self == y""" if comm.get().get_world_size() == 2: return (self - y)._eqz_2PC(_scale=_scale) return 1 - self.ne(y, _scale=_scale) @mode(Ptype.arithmetic) def ne(self, y, _scale=True): """Returns self != y""" if comm.get().get_world_size() == 2: return 1 - self.eq(y, _scale=_scale) difference = self - y difference.share = torch_stack([difference.share, -(difference.share)]) return difference._ltz(_scale=_scale).sum(0) @mode(Ptype.arithmetic) def _eqz_2PC(self, _scale=True): """Returns self == 0""" # Create BinarySharedTensors from shares x0 = MPCTensor(self.share, src=0, ptype=Ptype.binary) x1 = MPCTensor(-self.share, src=1, ptype=Ptype.binary) # Perform equality testing using binary shares x0._tensor = x0._tensor.eq(x1._tensor) x0.encoder = x0.encoder if _scale else self.encoder # Convert to Arithmetic sharing result = x0.to(Ptype.arithmetic, bits=1) if not _scale: result.encoder._scale = 1 return result @mode(Ptype.arithmetic) def sign(self, _scale=True): """Computes the sign 
value of a tensor (0 is considered positive)""" return 1 - 2 * self._ltz(_scale=_scale) @mode(Ptype.arithmetic) def abs(self): """Computes the absolute value of a tensor""" return self * self.sign(_scale=False) @mode(Ptype.arithmetic) def relu(self): """Compute a Rectified Linear function on the input tensor.""" return self * self.ge(0, _scale=False) @mode(Ptype.arithmetic) def weighted_index(self, dim=None): """ Returns a tensor with entries that are one-hot along dimension `dim`. These one-hot entries are set at random with weights given by the input `self`. Examples:: >>> encrypted_tensor = MPCTensor(torch.tensor([1., 6.])) >>> index = encrypted_tensor.weighted_index().get_plain_text() # With 1 / 7 probability torch.tensor([1., 0.]) # With 6 / 7 probability torch.tensor([0., 1.]) """ if dim is None: return self.flatten().weighted_index(dim=0).view(self.size()) x = self.cumsum(dim) max_weight = x.index_select( dim, torch.tensor(x.size(dim) - 1, device=self.device)) r = MPCTensor.rand(max_weight.size(), device=self.device) * max_weight gt = x.gt(r, _scale=False) shifted = gt.roll(1, dims=dim) shifted.share.index_fill_(dim, torch.tensor(0, device=self.device), 0) return gt - shifted @mode(Ptype.arithmetic) def weighted_sample(self, dim=None): """ Samples a single value across dimension `dim` with weights corresponding to the values in `self` Returns the sample and the one-hot index of the sample. Examples:: >>> encrypted_tensor = MPCTensor(torch.tensor([1., 6.])) >>> index = encrypted_tensor.weighted_sample().get_plain_text() # With 1 / 7 probability (torch.tensor([1., 0.]), torch.tensor([1., 0.])) # With 6 / 7 probability (torch.tensor([0., 6.]), torch.tensor([0., 1.])) """ indices = self.weighted_index(dim) sample = self.mul(indices).sum(dim) return sample, indices # max / min-related functions @mode(Ptype.arithmetic) def argmax(self, dim=None, keepdim=False, one_hot=True): """Returns the indices of the maximum value of all elements in the `input` tensor. 
""" # TODO: Make dim an arg. if self.dim() == 0: result = (MPCTensor(torch.ones( (), device=self.device)) if one_hot else MPCTensor( torch.zeros((), device=self.device))) return result result = _argmax_helper(self, dim, one_hot, config.max_method, _return_max=False) if not one_hot: result = _one_hot_to_index(result, dim, keepdim, self.device) return result @mode(Ptype.arithmetic) def argmin(self, dim=None, keepdim=False, one_hot=True): """Returns the indices of the minimum value of all elements in the `input` tensor. """ # TODO: Make dim an arg. return (-self).argmax(dim=dim, keepdim=keepdim, one_hot=one_hot) @mode(Ptype.arithmetic) def max(self, dim=None, keepdim=False, one_hot=True): """Returns the maximum value of all elements in the input tensor.""" # TODO: Make dim an arg. method = config.max_method if dim is None: if method in ["log_reduction", "double_log_reduction"]: # max_result can be obtained directly max_result = _max_helper_all_tree_reductions(self, method=method) else: # max_result needs to be obtained through argmax with ConfigManager("max_method", method): argmax_result = self.argmax(one_hot=True) max_result = self.mul(argmax_result).sum() return max_result else: argmax_result, max_result = _argmax_helper(self, dim=dim, one_hot=True, method=method, _return_max=True) if max_result is None: max_result = (self * argmax_result).sum(dim=dim, keepdim=keepdim) if keepdim: max_result = (max_result.unsqueeze(dim) if max_result.dim() < self.dim() else max_result) if one_hot: return max_result, argmax_result else: return ( max_result, _one_hot_to_index(argmax_result, dim, keepdim, self.device), ) @mode(Ptype.arithmetic) def min(self, dim=None, keepdim=False, one_hot=True): """Returns the minimum value of all elements in the input tensor.""" # TODO: Make dim an arg. 
result = (-self).max(dim=dim, keepdim=keepdim, one_hot=one_hot) if dim is None: return -result else: return -result[0], result[1] @mode(Ptype.arithmetic) def max_pool2d( self, kernel_size, padding=0, stride=None, dilation=1, ceil_mode=False, return_indices=False, ): """Applies a 2D max pooling over an input signal composed of several input planes. """ max_input = self.shallow_copy() max_input.share, output_size = pool2d_reshape( self.share, kernel_size, padding=padding, stride=stride, dilation=dilation, ceil_mode=ceil_mode, # padding with extremely negative values to avoid choosing pads. # The magnitude of this value should not be too large because # multiplication can otherwise fail. pad_value=(-(2**24)), # TODO: Find a better solution for padding with max_pooling ) max_vals, argmax_vals = max_input.max(dim=-1, one_hot=True) max_vals = max_vals.view(output_size) if return_indices: if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) argmax_vals = argmax_vals.view(output_size + kernel_size) return max_vals, argmax_vals return max_vals @mode(Ptype.arithmetic) def _max_pool2d_backward( self, indices, kernel_size, padding=None, stride=None, dilation=1, ceil_mode=False, output_size=None, ): """Implements the backwards for a `max_pool2d` call.""" # Setup padding if padding is None: padding = 0 if isinstance(padding, int): padding = padding, padding assert isinstance(padding, tuple), "padding must be a int, tuple, or None" p0, p1 = padding # Setup stride if stride is None: stride = kernel_size if isinstance(stride, int): stride = stride, stride assert isinstance(stride, tuple), "stride must be a int, tuple, or None" s0, s1 = stride # Setup dilation if isinstance(stride, int): dilation = dilation, dilation assert isinstance(dilation, tuple), "dilation must be a int, tuple, or None" d0, d1 = dilation # Setup kernel_size if isinstance(kernel_size, int): kernel_size = kernel_size, kernel_size assert isinstance(padding, tuple), "padding must be a int or 
tuple" k0, k1 = kernel_size assert self.dim( ) == 4, "Input to _max_pool2d_backward must have 4 dimensions" assert ( indices.dim() == 6 ), "Indices input for _max_pool2d_backward must have 6 dimensions" # Computes one-hot gradient blocks from each output variable that # has non-zero value corresponding to the argmax of the corresponding # block of the max_pool2d input. kernels = self.view(self.size() + (1, 1)) * indices # Use minimal size if output_size is not specified. if output_size is None: output_size = ( self.size(0), self.size(1), s0 * self.size(2) - 2 * p0, s1 * self.size(3) - 2 * p1, ) # Account for input padding result_size = list(output_size) result_size[-2] += 2 * p0 result_size[-1] += 2 * p1 # Account for input padding implied by ceil_mode if ceil_mode: c0 = self.size(-1) * s1 + (k1 - 1) * d1 - output_size[-1] c1 = self.size(-2) * s0 + (k0 - 1) * d0 - output_size[-2] result_size[-2] += c0 result_size[-1] += c1 # Sum the one-hot gradient blocks at corresponding index locations. result = MPCTensor(torch.zeros(result_size, device=kernels.device)) for i in range(self.size(2)): for j in range(self.size(3)): left_ind = s0 * i top_ind = s1 * j result[:, :, left_ind:left_ind + k0 * d0:d0, top_ind:top_ind + k1 * d1:d1, ] += kernels[:, :, i, j] # Remove input padding if ceil_mode: result = result[:, :, :result.size(2) - c0, :result.size(3) - c1] result = result[:, :, p0:result.size(2) - p0, p1:result.size(3) - p1] return result def adaptive_avg_pool2d(self, output_size): r""" Applies a 2D adaptive average pooling over an input signal composed of several input planes. See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape. 
Args: output_size: the target output size (single integer or double-integer tuple) """ resized_input, args, kwargs = adaptive_pool2d_helper(self, output_size, reduction="mean") return resized_input.avg_pool2d(*args, **kwargs) def adaptive_max_pool2d(self, output_size, return_indices=False): r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes. See :class:`~torch.nn.AdaptiveMaxPool2d` for details and output shape. Args: output_size: the target output size (single integer or double-integer tuple) return_indices: whether to return pooling indices. Default: ``False`` """ resized_input, args, kwargs = adaptive_pool2d_helper(self, output_size, reduction="max") return resized_input.max_pool2d(*args, **kwargs, return_indices=return_indices) def where(self, condition, y): """Selects elements from self or y based on condition Args: condition (torch.bool or MPCTensor): when True yield self, otherwise yield y y (torch.tensor or MPCTensor): values selected at indices where condition is False. Returns: MPCTensor or torch.tensor """ if is_tensor(condition): condition = condition.float() y_masked = y * (1 - condition) else: # encrypted tensor must be first operand y_masked = (1 - condition) * y return self * condition + y_masked @mode(Ptype.arithmetic) def div(self, y): r"""Divides each element of :attr:`self` with the scalar :attr:`y` or each element of the tensor :attr:`y` and returns a new resulting tensor. For `y` a scalar: .. math:: \text{out}_i = \frac{\text{self}_i}{\text{y}} For `y` a tensor: .. math:: \text{out}_i = \frac{\text{self}_i}{\text{y}_i} Note for :attr:`y` a tensor, the shapes of :attr:`self` and :attr:`y` must be `broadcastable`_. .. 
_broadcastable: https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics""" # noqa: B950 result = self.clone() if isinstance(y, CrypTensor): result.share = torch.broadcast_tensors(result.share, y.share)[0].clone() elif is_tensor(y): result.share = torch.broadcast_tensors(result.share, y)[0].clone() if isinstance(y, MPCTensor): return result.mul(y.reciprocal()) result._tensor.div_(y) return result def div_(self, y): """In-place version of :meth:`div`""" if isinstance(y, MPCTensor): return self.mul_(y.reciprocal()) self._tensor.div_(y) return self def index_add(self, dim, index, tensor): """Performs out-of-place index_add: Accumulate the elements of tensor into the self tensor by adding to the indices in the order given in index. """ result = self.clone() assert index.dim() == 1, "index needs to be a vector" public = isinstance(tensor, (int, float)) or is_tensor(tensor) private = isinstance(tensor, MPCTensor) if public: result._tensor.index_add_(dim, index, tensor) elif private: result._tensor.index_add_(dim, index, tensor._tensor) else: raise TypeError("index_add second tensor of unsupported type") return result def scatter_add(self, dim, index, other): """Adds all values from the tensor other into self at the indices specified in the index tensor. 
""" result = self.clone() public = isinstance(other, (int, float)) or is_tensor(other) private = isinstance(other, CrypTensor) if public: result._tensor.scatter_add_(dim, index, other) elif private: result._tensor.scatter_add_(dim, index, other._tensor) else: raise TypeError("scatter_add second tensor of unsupported type") return result def scatter(self, dim, index, src): """Out-of-place version of :meth:`MPCTensor.scatter_`""" result = self.clone() if is_tensor(src): src = MPCTensor(src) assert isinstance( src, MPCTensor), "Unrecognized scatter src type: %s" % type(src) result.share.scatter_(dim, index, src.share) return result def unbind(self, dim=0): shares = self.share.unbind(dim=dim) results = tuple( MPCTensor(0, ptype=self.ptype, device=self.device) for _ in range(len(shares))) for i in range(len(shares)): results[i].share = shares[i] return results def split(self, split_size, dim=0): shares = self.share.split(split_size, dim=dim) results = tuple( MPCTensor(0, ptype=self.ptype, device=self.device) for _ in range(len(shares))) for i in range(len(shares)): results[i].share = shares[i] return results def set(self, enc_tensor): """ Sets self encrypted to enc_tensor in place by setting shares of self to those of enc_tensor. Args: enc_tensor (MPCTensor): with encrypted shares. """ if is_tensor(enc_tensor): enc_tensor = MPCTensor(enc_tensor) assert isinstance(enc_tensor, MPCTensor), "enc_tensor must be an MPCTensor" self.share.set_(enc_tensor.share) return self