def __and__(self, y):
    """Bitwise AND operator (element-wise)"""
    result = self.clone()
    # TODO: Remove explicit broadcasts to allow smaller beaver triples
    if isinstance(y, BinarySharedTensor):
        broadcast_tensors = torch.broadcast_tensors(result.share, y.share)
        result.share = broadcast_tensors[0].clone()
    elif is_tensor(y):
        broadcast_tensors = torch.broadcast_tensors(result.share, y)
        result.share = broadcast_tensors[0].clone()
    return result.__iand__(y)
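# Usage sketch (hedged): with a CrypTen-style runtime, the `&` operator on
# binary-shared tensors dispatches here. `crypten.init()`, `crypten.cryptensor`,
# and the `ptype=crypten.mpc.binary` keyword are assumed entry points that do
# not appear in this file.
#
#     import torch
#     import crypten
#
#     crypten.init()
#     x = crypten.cryptensor(torch.tensor([12, 10]), ptype=crypten.mpc.binary)
#     y = crypten.cryptensor(torch.tensor([10, 6]), ptype=crypten.mpc.binary)
#     z = x & y                  # element-wise AND on secret shares
#     print(z.get_plain_text())  # 12 & 10 == 8, 10 & 6 == 2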
def scatter_(self, dim, index, src):
    """Writes all values from the tensor `src` into `self` at the indices
    specified in the `index` tensor. For each value in `src`, its output index
    is specified by its index in `src` for `dimension != dim` and by the
    corresponding value in `index` for `dimension = dim`.
    """
    if is_tensor(src):
        src = ArithmeticSharedTensor(src)
    assert isinstance(
        src, ArithmeticSharedTensor
    ), "Unrecognized scatter src type: %s" % type(src)
    self.share.scatter_(dim, index, src.share)
    return self
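# Usage sketch (hedged): assumes the MPCTensor front end forwards `scatter_`
# to this method, and that `crypten.init()` / `crypten.cryptensor` exist as in
# CrypTen. A plaintext `src` is wrapped into shares by the branch above.
#
#     import torch
#     import crypten
#
#     crypten.init()
#     x = crypten.cryptensor(torch.zeros(2, 4))
#     index = torch.tensor([[0, 1, 0, 1]])
#     src = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
#     x.scatter_(0, index, src)      # x[index[i][j]][j] = src[i][j] for dim=0
#     print(x.get_plain_text())
#     # tensor([[1., 0., 3., 0.],
#     #         [0., 2., 0., 4.]])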
def scatter_add_(self, dim, index, other):
    """Adds all values from the tensor `other` into `self` at the indices
    specified in the `index` tensor."""
    public = isinstance(other, (int, float)) or is_tensor(other)
    private = isinstance(other, CrypTensor)
    if public:
        self._tensor.scatter_add_(dim, index, other)
    elif private:
        self._tensor.scatter_add_(dim, index, other._tensor)
    else:
        raise TypeError("scatter_add second tensor of unsupported type")
    return self
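# Usage sketch (hedged): the private branch adds another encrypted tensor
# share-wise, so neither operand is revealed. `crypten.init()` and
# `crypten.cryptensor` are assumed from a CrypTen-style runtime.
#
#     import torch
#     import crypten
#
#     crypten.init()
#     x = crypten.cryptensor(torch.zeros(2, 3))
#     index = torch.tensor([[0, 1, 0]])
#     other = crypten.cryptensor(torch.tensor([[5.0, 6.0, 7.0]]))
#     x.scatter_add_(0, index, other)   # CrypTensor branch: share-wise add
#     print(x.get_plain_text())
#     # tensor([[5., 0., 7.],
#     #         [0., 6., 0.]])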
def stack(tensors, *args, **kwargs):
    """Perform tensor stacking"""
    for i, tensor in enumerate(tensors):
        if is_tensor(tensor):
            tensors[i] = ArithmeticSharedTensor(tensor)
        assert isinstance(
            tensors[i], ArithmeticSharedTensor
        ), "Can't stack %s with ArithmeticSharedTensor" % type(tensor)

    result = tensors[0].shallow_copy()
    result.share = torch_stack(
        [tensor.share for tensor in tensors], *args, **kwargs
    )
    return result
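# Usage sketch (hedged): at the public API level this corresponds to
# `crypten.stack` in a CrypTen-style library; plaintext torch tensors in the
# list are wrapped into shares first by the loop above.
#
#     import torch
#     import crypten
#
#     crypten.init()
#     a = crypten.cryptensor(torch.tensor([1.0, 2.0]))
#     b = crypten.cryptensor(torch.tensor([3.0, 4.0]))
#     c = crypten.stack([a, b])      # new leading dimension, shape (2, 2)
#     print(c.get_plain_text())      # tensor([[1., 2.], [3., 4.]])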
def set(self, enc_tensor):
    """
    Sets self to the value of `enc_tensor` in place by setting the shares of
    self to those of `enc_tensor`.

    Args:
        enc_tensor (MPCTensor): tensor whose encrypted shares to copy.
    """
    if is_tensor(enc_tensor):
        enc_tensor = MPCTensor(enc_tensor)
    assert isinstance(enc_tensor, MPCTensor), "enc_tensor must be an MPCTensor"
    self.share.set_(enc_tensor.share)
    return self
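# Usage sketch (hedged): because `set` overwrites the shares in place via
# `Tensor.set_`, existing references to `self` remain valid afterwards.
#
#     import torch
#     import crypten
#
#     crypten.init()
#     x = crypten.cryptensor(torch.zeros(3))
#     y = crypten.cryptensor(torch.tensor([1.0, 2.0, 3.0]))
#     x.set(y)                      # x now holds y's shares
#     print(x.get_plain_text())     # tensor([1., 2., 3.])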
def index_add_(self, dim, index, tensor):
    """Perform in-place index_add: Accumulate the elements of `tensor` into
    the self tensor by adding to the indices in the order given in `index`.
    """
    public = isinstance(tensor, (int, float)) or is_tensor(tensor)
    private = isinstance(tensor, ArithmeticSharedTensor)
    if public:
        enc_tensor = self.encoder.encode(tensor)
        if self.rank == 0:
            self._tensor.index_add_(dim, index, enc_tensor)
    elif private:
        self._tensor.index_add_(dim, index, tensor._tensor)
    else:
        raise TypeError("index_add second tensor of unsupported type")
    return self
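# Usage sketch (hedged): a public value is fixed-point encoded and added by
# rank 0 only, so the plaintext enters the additively shared sum exactly once
# across parties.
#
#     import torch
#     import crypten
#
#     crypten.init()
#     x = crypten.cryptensor(torch.zeros(4))
#     index = torch.tensor([1, 3])
#     x.index_add_(0, index, torch.tensor([10.0, 20.0]))
#     print(x.get_plain_text())     # tensor([ 0., 10.,  0., 20.])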
def index_add_(self, dim, index, tensor):
    """Performs in-place index_add: Accumulate the elements of `tensor` into
    the self tensor by adding to the indices in the order given in `index`.
    """
    assert index.dim() == 1, "index needs to be a vector"
    public = isinstance(tensor, (int, float)) or is_tensor(tensor)
    private = isinstance(tensor, MPCTensor)
    if public:
        self._tensor.index_add_(dim, index, tensor)
    elif private:
        self._tensor.index_add_(dim, index, tensor._tensor)
    else:
        raise TypeError("index_add second tensor of unsupported type")
    return self
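# Usage sketch (hedged): the private branch accumulates another encrypted
# tensor share-wise, without revealing either operand.
#
#     import torch
#     import crypten
#
#     crypten.init()
#     x = crypten.cryptensor(torch.zeros(4))
#     index = torch.tensor([0, 2])
#     t = crypten.cryptensor(torch.tensor([1.5, 2.5]))
#     x.index_add_(0, index, t)     # MPCTensor branch: share-wise add
#     print(x.get_plain_text())     # tensor([1.5000, 0.0000, 2.5000, 0.0000])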
def scatter_add_(self, dim, index, other):
    """Adds all values from the tensor `other` into `self` at the indices
    specified in the `index` tensor in a similar fashion as scatter_(). For
    each value in `other`, it is added to an index in `self` which is
    specified by its index in `other` for `dimension != dim` and by the
    corresponding value in `index` for `dimension = dim`.
    """
    public = isinstance(other, (int, float)) or is_tensor(other)
    private = isinstance(other, ArithmeticSharedTensor)
    if public:
        if self.rank == 0:
            self.share.scatter_add_(dim, index, self.encoder.encode(other))
    elif private:
        self.share.scatter_add_(dim, index, other.share)
    else:
        raise TypeError("scatter_add second tensor of unsupported type")
    return self
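# Usage sketch (hedged): same indexing rules as `torch.Tensor.scatter_add_`;
# the public branch encodes `other` and lets only rank 0 add it, so the
# plaintext offset enters the replicated sum exactly once.
#
#     import torch
#     import crypten
#
#     crypten.init()
#     x = crypten.cryptensor(torch.ones(2, 3))
#     index = torch.tensor([[0, 1, 0]])
#     other = torch.tensor([[5.0, 6.0, 7.0]])
#     x.scatter_add_(0, index, other)   # x[index[i][j]][j] += other[i][j]
#     print(x.get_plain_text())
#     # tensor([[6., 1., 8.],
#     #         [1., 7., 1.]])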
def where(self, condition, y):
    """Selects elements from self or y based on condition

    Args:
        condition (torch.bool or ArithmeticSharedTensor): when True yield
            self, otherwise yield y.
        y (torch.tensor or ArithmeticSharedTensor): values selected at
            indices where condition is False.

    Returns: ArithmeticSharedTensor or torch.tensor
    """
    if is_tensor(condition):
        condition = condition.float()
        y_masked = y * (1 - condition)
    else:
        # encrypted tensor must be first operand
        y_masked = (1 - condition) * y
    return self * condition + y_masked
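# Usage sketch (hedged): `condition` may be a public boolean tensor (cheap,
# local arithmetic) or an encrypted 0/1 tensor (oblivious selection); either
# way the result is self * condition + y * (1 - condition).
#
#     import torch
#     import crypten
#
#     crypten.init()
#     x = crypten.cryptensor(torch.tensor([1.0, 2.0, 3.0]))
#     y = crypten.cryptensor(torch.tensor([-1.0, -2.0, -3.0]))
#     cond = torch.tensor([True, False, True])
#     z = x.where(cond, y)
#     print(z.get_plain_text())     # tensor([ 1., -2.,  3.])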
def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs):
    assert op in [
        "add",
        "sub",
        "mul",
        "matmul",
        "conv1d",
        "conv2d",
        "conv_transpose1d",
        "conv_transpose2d",
    ], f"Provided op `{op}` is not a supported arithmetic function"

    additive_func = op in ["add", "sub"]
    public = isinstance(y, (int, float)) or is_tensor(y)
    private = isinstance(y, ArithmeticSharedTensor)

    if inplace:
        result = self
        if additive_func or (op == "mul" and public):
            op += "_"
    else:
        result = self.clone()

    if public:
        y = result.encoder.encode(y, device=self.device)
        if additive_func:  # ['add', 'sub']
            if result.rank == 0:
                result.share = getattr(result.share, op)(y)
            else:
                result.share = torch.broadcast_tensors(result.share, y)[0]
        elif op == "mul_":  # ['mul_']
            result.share = result.share.mul_(y)
        else:  # ['mul', 'matmul', 'convNd', 'conv_transposeNd']
            result.share = getattr(torch, op)(result.share, y, *args, **kwargs)
    elif private:
        if additive_func:  # ['add', 'sub', 'add_', 'sub_']
            result.share = getattr(result.share, op)(y.share)
        else:  # ['mul', 'matmul', 'convNd', 'conv_transposeNd']
            # NOTE: 'mul_' calls 'mul' here
            # Must copy share.data here to support 'mul_' being inplace
            result.share.set_(
                getattr(beaver, op)(result, y, *args, **kwargs).share.data
            )
    else:
        raise TypeError("Cannot %s %s with %s" % (op, type(y), type(self)))

    # Scale by encoder scale if necessary
    if not additive_func:
        if public:  # scale by self.encoder.scale
            if self.encoder.scale > 1:
                return result.div_(result.encoder.scale)
            else:
                result.encoder = self.encoder
        else:  # scale by larger of self.encoder.scale and y.encoder.scale
            if self.encoder.scale > 1 and y.encoder.scale > 1:
                return result.div_(result.encoder.scale)
            elif self.encoder.scale > 1:
                result.encoder = self.encoder
            else:
                result.encoder = y.encoder

    return result
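# Usage sketch (hedged): the public arithmetic entry points (`add`, `mul`,
# `matmul`, the convolutions, and their in-place forms) all funnel into
# `_arithmetic_function`. Public operands touch only local shares; private
# multiplicative ops go through the Beaver-triple protocol in `beaver`.
#
#     import torch
#     import crypten
#
#     crypten.init()
#     x = crypten.cryptensor(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
#     y = crypten.cryptensor(torch.tensor([[1.0, 0.0], [0.0, 1.0]]))
#     print((x + 1).get_plain_text())      # public add: rank 0 adds encoding
#     print((x * 2).get_plain_text())      # public mul: every rank scales shares
#     print(x.matmul(y).get_plain_text())  # private matmul: Beaver protocol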